From 83fc696cc7a2e461cc113fa47fcdac29e0b882e4 Mon Sep 17 00:00:00 2001 From: Felix Ableitner Date: Thu, 7 Apr 2016 13:54:36 +0200 Subject: [PATCH] Implemented AODVv2, including integration test (fixes #33). For documentation on how AODVv2 works, see this link: https://datatracker.ietf.org/doc/draft-ietf-manet-aodvv2/ Note that this implementation is incompatible with AODVv2 itself, as various details are changed, and not all features have been implemented --- PROTOCOL.md | 90 ++++++-- README.md | 14 +- android/build.gradle | 17 +- android/proguard-rules.txt | 36 +++ .../bluetooth/BluetoothInterface.scala | 17 +- core/build.gradle | 1 + core/src/main/resources/logback.xml | 2 +- .../com/nutomic/ensichat/core/Address.scala | 5 + .../ensichat/core/ConnectionHandler.scala | 209 ++++++++++++++++-- .../com/nutomic/ensichat/core/Crypto.scala | 20 +- .../com/nutomic/ensichat/core/Message.scala | 12 +- .../com/nutomic/ensichat/core/Router.scala | 62 ++++-- .../ensichat/core/body/EncryptedBody.scala | 2 + .../ensichat/core/body/RouteError.scala | 39 ++++ .../ensichat/core/body/RouteReply.scala | 50 +++++ .../ensichat/core/body/RouteRequest.scala | 44 ++++ .../internet/InternetConnectionThread.scala | 7 +- .../core/internet/InternetInterface.scala | 35 +-- .../core/internet/InternetServerThread.scala | 8 +- .../nutomic/ensichat/core/util/Database.scala | 5 +- .../ensichat/core/util/LocalRoutesInfo.scala | 119 ++++++++++ .../ensichat/core/util/RouteMessageInfo.scala | 74 +++++++ .../nutomic/ensichat/core/CryptoTest.scala | 5 +- .../nutomic/ensichat/core/MessageTest.scala | 7 +- .../nutomic/ensichat/core/RouterTest.scala | 59 +++-- .../ensichat/core/body/RouteErrorTest.scala | 16 ++ .../ensichat/core/body/RouteReplyTest.scala | 16 ++ .../ensichat/core/body/RouteRequestTest.scala | 16 ++ .../core/util/LocalRoutesInfoTest.scala | 45 ++++ .../core/util/RouteMessageInfoTest.scala | 79 +++++++ integration/.gitignore | 1 + integration/build.gradle | 12 + .../LocalNode.scala | 84 +++++++ .../Main.scala | 135 +++++++++++ settings.gradle | 2 +- 35 files changed, 1208 insertions(+), 137 deletions(-) create mode 100644 android/proguard-rules.txt create mode 100644 core/src/main/scala/com/nutomic/ensichat/core/body/RouteError.scala create mode 100644 core/src/main/scala/com/nutomic/ensichat/core/body/RouteReply.scala create mode 100644 core/src/main/scala/com/nutomic/ensichat/core/body/RouteRequest.scala create mode 100644 core/src/main/scala/com/nutomic/ensichat/core/util/LocalRoutesInfo.scala create mode 100644 core/src/main/scala/com/nutomic/ensichat/core/util/RouteMessageInfo.scala create mode 100644 core/src/test/scala/com/nutomic/ensichat/core/body/RouteErrorTest.scala create mode 100644 core/src/test/scala/com/nutomic/ensichat/core/body/RouteReplyTest.scala create mode 100644 core/src/test/scala/com/nutomic/ensichat/core/body/RouteRequestTest.scala create mode 100644 core/src/test/scala/com/nutomic/ensichat/core/util/LocalRoutesInfoTest.scala create mode 100644 core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala create mode 100644 integration/.gitignore create mode 100644 integration/build.gradle create mode 100644 integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala create mode 100644 integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala diff --git a/PROTOCOL.md b/PROTOCOL.md index ab3a6f2..209754c 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -26,6 +26,9 @@ Nodes MUST NOT have a public key with the broadcast address or null address as hash. Additionally, nodes MUST NOT connect to a node with either address. +All integer fields are in network byte order, and unsigned (unless +specified otherwise). + Crypto ------ @@ -40,22 +43,15 @@ private key, and the result written to the 'Encryption Data' part. Routing ------- -A simple flood routing protocol is currently used. Every node forwards -all messages, unless a message with the same Origin and Sequence Number -has already been received. +The routing protocol is based on +[AODVv2](https://datatracker.ietf.org/doc/draft-ietf-manet-aodvv2/), +with various features left out. -Nodes MUST store pairs of (Origin, Sequence Number) for all received -messages. After receiving a new message, entries with the same Origin -and Sequence Number between _received_ + 1 and _received_ + 32767 MUST -be removed (with a wrap around at the maximum value). The entries MUST -NOT be cleared while the program is running. They MAY be cleared when -the program is exited. +TODO: Add Documentation for routing protocol. There is currently no support for offline messages. If sender and receiver are not in the same mesh, the message will not arrive. -Nodes are free implement different routing algorithms. - Messages -------- @@ -84,9 +80,7 @@ AES key is wrapped with the recipient's public RSA key. ### Header Every message starts with one 74 byte header indicating the message -version, type and ID, followed by the length of the message. The -header is in network byte order, i.e. big endian. The header may have -6 bytes of additional data. +version, type and ID, followed by the length of the message. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -220,6 +214,74 @@ After this message has been received, communication with normal messages may start. +### Route Request (Protocol-Type = 2) + +Sent to request a route to a specific Target Address. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Address (32 bytes) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | OrigSeqNum | OriginMetric | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | TargMetric | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +Equivalent to the Sequence Number in the message header. + +Set OrigMetric = RouterClient.Cost for the Router Client entry +which includes OrigAddr. + +If an Invalid route exists in the Local Route Set matching +TargAddr using longest prefix matching and has a valid +sequence number, set TargSeqNum = LocalRoute.SeqNum. +Otherwise, set TargSeqNum = -1. This field is signed. + +### Route Reply (Protocol-Type = 3) + +Sent as a reply when a Route Request arrives, to inform other nodes +about a route. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | TargSeqNum | TargMetric | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +Set TargMetric = RouterClient.Cost for the Router Client entry +which includes TargAddr. + +### Route Error (Protocol-Type = 4) + +Notifies other nodes of routes that are no longer available. The target +address MUST be set to the broadcast address. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Packet Source (32 bytes) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | Address (32 bytes) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SeqNum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +Packet Source is the source address of the message triggering this +Route Error. If the route error is not triggered by a message, +this MUST be set to the null address. + +Address is the address that is no longer reachable. + +SeqNum is the sequence number of the route that is no longer available +(if known). Otherwise, set TargSeqNum = -1. This field is signed. + Content Messages ---------------- diff --git a/README.md b/README.md index 6ffb875..2bbddf6 100644 --- a/README.md +++ b/README.md @@ -18,10 +18,20 @@ To setup a development environment, just install [Android Studio](https://develo and import the project. Alternatively, you can use the command line. To create a debug apk, run `./gradlew assembleDevDebug`. -This requires at least Android Lollipop on your development device. If you don't have Lollipop, you -can alternatively use `./gradlew assembleRelDebug`. However, this results in considerably slower +This requires at least Android Lollipop on your development device. If you don't have 5.0 or higher, +you have to use `./gradlew assembleRelDebug`. However, this results in considerably slower incremental builds. To create a release apk, run `./gradlew assembleRelRelease`. +Testing +------- + +You can run the unit tests with `./gradlew test`. After connecting an Android device, you can run +the Android tests with `./gradlew connectedDevDebugAndroidTest` (or +`./gradlew connectedRelDebugAndroidTest` if your Android version is lower than 5.0). + +To run integration tests for the core module, use `./gradlew integration:run`. If this fails (or +is very slow), try changing the value of Crypto#PublicKeySize to 512 (in the core module). + License ------- diff --git a/android/build.gradle b/android/build.gradle index 5dfb5f8..22e783e 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -11,7 +11,7 @@ buildscript { } dependencies { - compile 'com.android.support:design:23.1.1' + compile 'com.android.support:design:23.4.0' compile 'com.android.support:multidex:1.0.1' compile 'org.scala-lang:scala-library:2.11.7' compile 'com.mobsandgeeks:adapter-kit:0.5.3' @@ -46,9 +46,18 @@ android { testInstrumentationRunner "com.android.test.runner.MultiDexTestRunner" } - buildTypes.debug { - applicationIdSuffix ".debug" - testCoverageEnabled true + buildTypes { + debug { + applicationIdSuffix ".debug" + testCoverageEnabled true + } + release { + // HACK: This shouldn't be needed, but multidex isn't working correctly. + // https://code.google.com/p/android/issues/detail?id=206131 + // https://code.google.com/p/android/issues/detail?id=209084 + minifyEnabled true + proguardFiles getDefaultProguardFile('proguard-android.txt'), file('proguard-rules.txt') + } } // Increasing minSdkVersion reduces compilation time for MultiDex. diff --git a/android/proguard-rules.txt b/android/proguard-rules.txt new file mode 100644 index 0000000..450c62d --- /dev/null +++ b/android/proguard-rules.txt @@ -0,0 +1,36 @@ +# Add project specific ProGuard rules here. +# By default, the flags in this file are appended to flags specified +# in /home/sg/adt/sdk/tools/proguard/proguard-android.txt +# You can edit the include path and order by changing the ProGuard +# include property in project.properties. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# Add any project specific keep options here: + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +-dontobfuscate + +-keep class scala.** { *; } +-keep class slick.** { *; } +-keep class org.joda.time.** { *; } +-keep class org.h2.** { *; } +-keep class java.util.** { *; } +-keepclassmembers class java.util.Comparator { + public *; +} + +-dontwarn scala.** +-dontwarn slick.** +-dontwarn org.joda.time.** +-dontwarn org.h2.** +-dontwarn java.util.function.** + +-ignorewarnings diff --git a/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala b/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala index 46a2937..42d37e1 100644 --- a/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala +++ b/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala @@ -170,10 +170,11 @@ class BluetoothInterface(context: Context, mainHandler: Handler, * Removes device from active connections. */ def onConnectionClosed(device: Device, socket: BluetoothSocket): Unit = { + val address = getAddressForDevice(device.id) devices -= device.id connections -= device.id - connectionHandler.onConnectionClosed() addressDeviceMap = addressDeviceMap.filterNot(_._2 == device.id) + connectionHandler.onConnectionClosed(address) } /** @@ -192,15 +193,18 @@ class BluetoothInterface(context: Context, mainHandler: Handler, if (!connectionHandler.onConnectionOpened(msg)) addressDeviceMap -= address case _ => - connectionHandler.onMessageReceived(msg) + connectionHandler.onMessageReceived(msg, getAddressForDevice(device)) } + private def getAddressForDevice(device: Device.ID) = + addressDeviceMap.find(_._2 == device).get._1 + /** * Sends the message to nextHop. */ override def send(nextHop: Address, msg: Message): Unit = { addressDeviceMap - .find(_._1 == nextHop) + .find(_._1 == nextHop || Address.Broadcast == nextHop) .map(i => connections.get(i._2)) .getOrElse(None) .foreach(_.send(msg)) @@ -210,11 +214,6 @@ class BluetoothInterface(context: Context, mainHandler: Handler, * Returns all active Bluetooth connections. */ override def getConnections: Set[Address] = - connections.flatMap { x => - addressDeviceMap - .find(_._2 == x._1) - .map(_._1) - } - .toSet + connections.map( c => getAddressForDevice(c._1)).toSet } diff --git a/core/build.gradle b/core/build.gradle index ff6916c..7cb01d1 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -5,6 +5,7 @@ dependencies { compile 'com.h2database:h2:1.4.191' compile 'com.typesafe.slick:slick_2.11:3.1.1' compile 'com.typesafe.scala-logging:scala-logging_2.11:3.4.0' + compile 'joda-time:joda-time:2.9.3' testCompile 'junit:junit:4.12' } diff --git a/core/src/main/resources/logback.xml b/core/src/main/resources/logback.xml index 42b3acb..3087979 100644 --- a/core/src/main/resources/logback.xml +++ b/core/src/main/resources/logback.xml @@ -3,7 +3,7 @@ System.out - %d{HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n + %d{HH:mm:ss} %level/%logger{0}: %msg%n diff --git a/core/src/main/scala/com/nutomic/ensichat/core/Address.scala b/core/src/main/scala/com/nutomic/ensichat/core/Address.scala index cab8638..b71141a 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/Address.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/Address.scala @@ -56,4 +56,9 @@ final case class Address(bytes: Array[Byte]) { .grouped(Address.GroupLength) .reduce(_ + "-" + _) + /** + * Returns shortened address, useful for debugging. + */ + def short = toString.split("-").head + } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala b/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala index 10b2143..dc22815 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala @@ -1,15 +1,18 @@ package com.nutomic.ensichat.core +import java.security.InvalidKeyException import java.util.Date -import com.nutomic.ensichat.core.body.{ConnectionInfo, MessageBody, UserInfo} -import com.nutomic.ensichat.core.header.ContentHeader +import com.nutomic.ensichat.core.body._ +import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader} import com.nutomic.ensichat.core.interfaces._ import com.nutomic.ensichat.core.internet.InternetInterface -import com.nutomic.ensichat.core.util.{Database, FutureHelper} +import com.nutomic.ensichat.core.util.{Database, FutureHelper, LocalRoutesInfo, RouteMessageInfo} import com.typesafe.scalalogging.Logger import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future +import scala.concurrent.duration._ /** * High-level handling of all message transfers and callbacks. @@ -19,16 +22,33 @@ import scala.concurrent.ExecutionContext.Implicits.global */ final class ConnectionHandler(settings: SettingsInterface, database: Database, callbacks: CallbackInterface, crypto: Crypto, - maxInternetConnections: Int) { + maxInternetConnections: Int, + port: Int = InternetInterface.DefaultPort) { private val logger = Logger(this.getClass) + private val MissingRouteMessageTimeout = 5.minutes + private var transmissionInterfaces = Set[TransmissionInterface]() - private lazy val router = new Router(connections, sendVia) - private lazy val seqNumGenerator = new SeqNumGenerator(settings) + private val localRoutesInfo = new LocalRoutesInfo(connections) + + private val routeMessageInfo = new RouteMessageInfo() + + private lazy val router = new Router(localRoutesInfo, + (a, m) => transmissionInterfaces.foreach(_.send(a, m)), + noRouteFound) + + /** + * Contains messages that couldn't be forwarded because we don't know a route. + * + * These will be buffered until we receive a [[RouteReply]] for the target, or when until the + * message has couldn't be forwarded after [[MissingRouteMessageTimeout]]. + */ + private var missingRouteMessages = Set[(Message, Date)]() + /** * Holds all known users. * @@ -42,14 +62,15 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, * @param additionalInterfaces Instances of [[TransmissionInterface]] to transfer data over * platform specific interfaces (eg Bluetooth). */ - def start(additionalInterfaces: Set[TransmissionInterface] = Set()): Unit = { + def start(additionalInterfaces: Set[TransmissionInterface] = Set()): Future[Unit] = { additionalInterfaces.foreach(transmissionInterfaces += _) FutureHelper { crypto.generateLocalKeys() logger.info("Service started, address is " + crypto.localAddress) logger.info("Local user is " + settings.get(SettingsInterface.KeyUserName, "none") + " with status '" + settings.get(SettingsInterface.KeyUserStatus, "") + "'") - transmissionInterfaces += new InternetInterface(this, crypto, settings, maxInternetConnections) + transmissionInterfaces += + new InternetInterface(this, crypto, settings, maxInternetConnections, port) transmissionInterfaces.foreach(_.create()) } } @@ -63,6 +84,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, * Sends a new message to the given target address. */ def sendTo(target: Address, body: MessageBody): Unit = { + assert(body.contentType != -1) FutureHelper { val messageId = settings.get("message_id", 0L) val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(), @@ -76,23 +98,165 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, } } - private def sendVia(nextHop: Address, msg: Message) = - transmissionInterfaces.foreach(_.send(nextHop, msg)) + private def requestRoute(target: Address): Unit = { + assert(localRoutesInfo.getRoute(target).isEmpty) + val seqNum = seqNumGenerator.next() + val targetSeqNum = localRoutesInfo.getRoute(target).map(_.seqNum).getOrElse(-1) + val body = new RouteRequest(target, seqNum, targetSeqNum, 0) + val header = new MessageHeader(body.protocolType, crypto.localAddress, Address.Broadcast, seqNum) + + val signed = crypto.sign(new Message(header, body)) + router.forwardMessage(signed) + } + + private def replyRoute(target: Address, replyTo: Address): Unit = { + val seqNum = seqNumGenerator.next() + val body = new RouteReply(seqNum, 0) + val header = new MessageHeader(body.protocolType, crypto.localAddress, replyTo, seqNum) + + val signed = crypto.sign(new Message(header, body)) + router.forwardMessage(signed) + } + + private def routeError(address: Address, packetSource: Option[Address]): Unit = { + val destination = packetSource.getOrElse(Address.Broadcast) + val header = new MessageHeader(RouteError.Type, crypto.localAddress, destination, + seqNumGenerator.next()) + val seqNum = localRoutesInfo.getRoute(address).map(_.seqNum).getOrElse(-1) + val body = new RouteError(address, seqNum) + + val signed = crypto.sign(new Message(header, body)) + router.forwardMessage(signed) + } + + /** + * Force connect to a sepcific internet. + * + * @param address An address in the format IP;port or hostname:port. + */ + def connect(address: String): Unit = { + transmissionInterfaces + .find(_.isInstanceOf[InternetInterface]) + .map(_.asInstanceOf[InternetInterface]) + .foreach(_.openConnection(address)) + } /** * Decrypts and verifies incoming messages, forwards valid ones to [[onNewMessage()]]. */ - def onMessageReceived(msg: Message): Unit = { + def onMessageReceived(msg: Message, previousHop: Address): Unit = { if (router.isMessageSeen(msg)) { logger.trace("Ignoring message from " + msg.header.origin + " that we already received") - } else if (msg.header.target == crypto.localAddress) { - crypto.verifyAndDecrypt(msg) match { - case Some(m) => onNewMessage(m) - case None => logger.info("Ignoring message with invalid signature from " + msg.header.origin) - } - } else { - router.forwardMessage(msg) + return } + + msg.body match { + case rreq: RouteRequest => + localRoutesInfo.addRoute(msg.header.origin, rreq.originSeqNum, previousHop, rreq.originMetric) + // TODO: Respecting this causes the RERR test to fail. We have to fix the implementation + // of isMessageRedundant() without breaking the test. + if (routeMessageInfo.isMessageRedundant(msg)) { + logger.info("Sending redundant RREQ") + //return + } + + if (crypto.localAddress == rreq.requested) + replyRoute(rreq.requested, msg.header.origin) + else { + val body = rreq.copy(originMetric = rreq.originMetric + 1) + + val forwardMsg = crypto.sign(new Message(msg.header, body)) + localRoutesInfo.getRoute(rreq.requested) match { + case Some(route) => router.forwardMessage(forwardMsg, Option(route.nextHop)) + case None => router.forwardMessage(forwardMsg, Option(Address.Broadcast)) + } + } + return + case rrep: RouteReply => + localRoutesInfo.addRoute(msg.header.origin, rrep.originSeqNum, previousHop, 0) + // TODO: See above (in RREQ handler). + if (routeMessageInfo.isMessageRedundant(msg)) { + logger.debug("Sending redundant RREP") + //return + } + + resendMissingRouteMessages() + + if (msg.header.target == crypto.localAddress) + return + + val existingRoute = localRoutesInfo.getRoute(msg.header.target) + val states = Set(LocalRoutesInfo.RouteStates.Active, LocalRoutesInfo.RouteStates.Idle) + if (existingRoute.isEmpty || !states.contains(existingRoute.get.state)) { + routeError(msg.header.target, Option(msg.header.origin)) + return + } + + val body = rrep.copy(originMetric = rrep.originMetric + 1) + + val forwardMsg = crypto.sign(new Message(msg.header, body)) + router.forwardMessage(forwardMsg) + return + case rerr: RouteError => + localRoutesInfo.getRoute(rerr.address).foreach { route => + if (route.nextHop == msg.header.origin && (rerr.seqNum == 0 || rerr.seqNum > route.seqNum)) { + localRoutesInfo.connectionClosed(rerr.address) + .foreach(routeError(_, None)) + } + } + case _ => + } + + if (msg.header.target != crypto.localAddress) { + router.forwardMessage(msg) + return + } + + val plainMsg = + try { + if (!crypto.verify(msg)) { + logger.warn(s"Received message with invalid signature from ${msg.header.origin}") + return + } + + if (msg.header.isContentMessage) + crypto.decrypt(msg) + else + msg + } catch { + case e: InvalidKeyException => + logger.warn(s"Failed to verify or decrypt message $msg", e) + return + } + + onNewMessage(plainMsg) + } + + /** + * Tries to send messages in [[missingRouteMessages]] again, after we acquired a new route. + * + * Before checking [[missingRouteMessages]], those older than [[MissingRouteMessageTimeout]] + * are removed. + */ + private def resendMissingRouteMessages(): Unit = { + // resend messages if possible + val date = new Date() + missingRouteMessages = missingRouteMessages.filter { e => + val removeTime = new Date(e._2.getTime + MissingRouteMessageTimeout.toMillis) + removeTime.after(date) + } + + val m = missingRouteMessages.filter(m => localRoutesInfo.getRoute(m._1.header.target).isDefined) + m.foreach( m => router.forwardMessage(m._1)) + missingRouteMessages --= m + } + + private def noRouteFound(message: Message): Unit = { + if (message.header.origin == crypto.localAddress) { + missingRouteMessages += ((message, new Date())) + requestRoute(message.header.target) + } else + routeError(message.header.target, Option(message.header.origin)) } /** @@ -163,7 +327,11 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, true } - def onConnectionClosed() = callbacks.onConnectionsChanged() + def onConnectionClosed(address: Address): Unit = { + localRoutesInfo.connectionClosed(address) + .foreach(routeError(_, None)) + callbacks.onConnectionsChanged() + } def connections(): Set[Address] = transmissionInterfaces.flatMap(_.getConnections) @@ -177,6 +345,9 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, .find(_.address == address) .getOrElse(new User(address, address.toString(), "")) + /** + * This method should be called when the local device's internet connection has changed in any way. + */ def internetConnectionChanged(): Unit = { transmissionInterfaces .find(_.isInstanceOf[InternetInterface]) diff --git a/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala b/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala index 2655aa4..a6f5be3 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/Crypto.scala @@ -119,7 +119,7 @@ class Crypto(settings: SettingsInterface, keyFolder: File) { * @throws RuntimeException If the key does not exist. */ @throws[RuntimeException] - private[core] def getPublicKey(address: Address): PublicKey = { + def getPublicKey(address: Address): PublicKey = { loadKey(address.toString, classOf[PublicKey]) } @@ -129,7 +129,7 @@ class Crypto(settings: SettingsInterface, keyFolder: File) { * @throws RuntimeException If a key already exists for this address. */ @throws[RuntimeException] - private[core] def addPublicKey(address: Address, key: PublicKey): Unit = { + def addPublicKey(address: Address, key: PublicKey): Unit = { if (havePublicKey(address)) throw new RuntimeException("Already have key for " + address + ", not overwriting") @@ -232,20 +232,6 @@ class Crypto(settings: SettingsInterface, keyFolder: File) { sign(encrypt(msg, key)) } - private[core] def verifyAndDecrypt(msg: Message, key: Option[PublicKey] = None): Option[Message] = { - // Catch exception to avoid crash if we receive invalid message. - try { - if (verify(msg, key)) - Option(decrypt(msg)) - else - None - } catch { - case e: InvalidKeyException => - logger.warn("Failed to verify or decrypt message", e) - None - } - } - private def encrypt(msg: Message, key: Option[PublicKey] = None): Message = { // Symmetric encryption of data val secretKey = makeSecretKey() @@ -263,7 +249,7 @@ class Crypto(settings: SettingsInterface, keyFolder: File) { } @throws[InvalidKeyException] - private def decrypt(msg: Message): Message = { + def decrypt(msg: Message): Message = { // Asymmetric decryption of secret key val asymmetricCipher = Cipher.getInstance(CipherAlgorithm) asymmetricCipher.init(Cipher.UNWRAP_MODE, loadKey(PrivateKeyAlias, classOf[PrivateKey])) diff --git a/core/src/main/scala/com/nutomic/ensichat/core/Message.scala b/core/src/main/scala/com/nutomic/ensichat/core/Message.scala index 1487a29..65eb958 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/Message.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/Message.scala @@ -3,7 +3,7 @@ package com.nutomic.ensichat.core import java.io.InputStream import java.security.spec.InvalidKeySpecException -import com.nutomic.ensichat.core.body.{ConnectionInfo, CryptoData, EncryptedBody, MessageBody} +import com.nutomic.ensichat.core.body._ import com.nutomic.ensichat.core.header.{AbstractHeader, ContentHeader, MessageHeader} object Message { @@ -50,6 +50,9 @@ object Message { val body = header.protocolType match { case ConnectionInfo.Type => ConnectionInfo.read(remaining) + case RouteRequest.Type => RouteRequest.read(remaining) + case RouteReply.Type => RouteReply.read(remaining) + case RouteError.Type => RouteError.read(remaining) case _ => new EncryptedBody(remaining) } @@ -80,6 +83,11 @@ case class Message(header: AbstractHeader, crypto: CryptoData, body: MessageBody def this(header: AbstractHeader, body: MessageBody) = this(header, new CryptoData(None, None), body) - def write = header.write(body.length + crypto.length) ++ crypto.write ++ body.write + def write = { + header.write(body.length + crypto.length) ++ crypto.write ++ body.write + } + + override def toString = + s"Message(${header.origin.short}(${header.seqNum}) -> ${header.target.short}: $body)" } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/Router.scala b/core/src/main/scala/com/nutomic/ensichat/core/Router.scala index a487917..02159b8 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/Router.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/Router.scala @@ -1,18 +1,44 @@ package com.nutomic.ensichat.core +import java.util.Comparator + import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader} +import com.nutomic.ensichat.core.util.LocalRoutesInfo + +object Router extends Comparator[Int] { + + /** + * Compares which sequence number is newer. + * + * @return 1 if lhs is newer, -1 if rhs is newer, 0 if they are equal. + */ + override def compare(lhs: Int, rhs: Int): Int = { + if (lhs == rhs) + 0 + // True if [[rhs]] is between {{{MessageHeader.SeqNumRange.size / 2}}} and + // [[MessageHeader.SeqNumRange.size]]. + else if (lhs > ContentHeader.SeqNumRange.size / 2) { + // True if [[rhs]] is between {{{lhs - MessageHeader.SeqNumRange.size / 2}}} and [[lhs]]. + if (lhs - ContentHeader.SeqNumRange.size / 2 < rhs && rhs < lhs) 1 else -1 + } else { + // True if [[rhs]] is *not* between [[lhs]] and {{{lhs + MessageHeader.SeqNumRange.size / 2}}}. + if (rhs < lhs || rhs > lhs + ContentHeader.SeqNumRange.size / 2) 1 else -1 + } + } +} /** * Forwards messages to all connected devices. */ -final private[core] class Router(activeConnections: () => Set[Address], send: (Address, Message) => Unit) { +private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message) => Unit, + noRouteFound: (Message) => Unit) { private var messageSeen = Set[(Address, Int)]() /** * Returns true if we have received the same message before. */ - def isMessageSeen(msg: Message): Boolean = { + private[core] def isMessageSeen(msg: Message): Boolean = { val info = (msg.header.origin, msg.header.seqNum) val seen = messageSeen.contains(info) markMessageSeen(info) @@ -23,15 +49,24 @@ final private[core] class Router(activeConnections: () => Set[Address], send: (A * Sends message to all connected devices. Should only be called if [[isMessageSeen()]] returns * true. */ - def forwardMessage(msg: Message): Unit = { - val info = (msg.header.origin, msg.header.seqNum) - val updated = incHopCount(msg) - if (updated.header.hopCount >= updated.header.hopLimit) + def forwardMessage(msg: Message, nextHopOption: Option[Address] = None): Unit = { + if (msg.header.hopCount + 1 >= msg.header.hopLimit) return - activeConnections().foreach(a => send(a, updated)) + val nextHop = nextHopOption.getOrElse(msg.header.target) - markMessageSeen(info) + if (nextHop == Address.Broadcast) { + send(nextHop, msg) + return + } + + routesInfo.getRoute(nextHop).map(_.nextHop) match { + case Some(a) => + send(a, incHopCount(msg)) + markMessageSeen((msg.header.origin, msg.header.seqNum)) + case None => + noRouteFound(msg) + } } private def markMessageSeen(info: (Address, Int)): Unit = { @@ -64,15 +99,8 @@ final private[core] class Router(activeConnections: () => Set[Address], send: (A if (a1 != a2) true - // True if [[s2]] is between {{{MessageHeader.SeqNumRange.size / 2}}} and - // [[MessageHeader.SeqNumRange.size]]. - if (s1 > ContentHeader.SeqNumRange.size / 2) { - // True if [[s2]] is between {{{s1 - MessageHeader.SeqNumRange.size / 2}}} and [[s1]]. - s1 - ContentHeader.SeqNumRange.size / 2 < s2 && s2 < s1 - } else { - // True if [[s2]] is *not* between [[s1]] and {{{s1 + MessageHeader.SeqNumRange.size / 2}}}. - s2 < s1 || s2 > s1 + ContentHeader.SeqNumRange.size / 2 - } + else + Router.compare(s1, s2) > 0 } } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/body/EncryptedBody.scala b/core/src/main/scala/com/nutomic/ensichat/core/body/EncryptedBody.scala index 5c6c18d..49b13fb 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/body/EncryptedBody.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/body/EncryptedBody.scala @@ -12,4 +12,6 @@ final case class EncryptedBody(data: Array[Byte]) extends MessageBody { def write = data override def length = data.length + + override def toString = "EncryptedBody" } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/body/RouteError.scala b/core/src/main/scala/com/nutomic/ensichat/core/body/RouteError.scala new file mode 100644 index 0000000..f996a3e --- /dev/null +++ b/core/src/main/scala/com/nutomic/ensichat/core/body/RouteError.scala @@ -0,0 +1,39 @@ +package com.nutomic.ensichat.core.body + +import java.nio.ByteBuffer + +import com.nutomic.ensichat.core.Address +import com.nutomic.ensichat.core.util.BufferUtils + +private[core] object RouteError { + + val Type = 4 + + /** + * Constructs [[RouteError]] instance from byte array. + */ + def read(array: Array[Byte]): RouteError = { + val b = ByteBuffer.wrap(array) + val address = new Address(BufferUtils.getByteArray(b, Address.Length)) + val seqNum = b.getInt + new RouteError(address, seqNum) + } + +} + +private[core] case class RouteError(address: Address, seqNum: Int) extends MessageBody { + + override def protocolType = RouteReply.Type + + override def contentType = -1 + + override def write: Array[Byte] = { + val b = ByteBuffer.allocate(length) + b.put(address.bytes) + b.putInt(seqNum) + b.array() + } + + override def length = Address.Length + 4 + +} diff --git a/core/src/main/scala/com/nutomic/ensichat/core/body/RouteReply.scala b/core/src/main/scala/com/nutomic/ensichat/core/body/RouteReply.scala new file mode 100644 index 0000000..33c23f0 --- /dev/null +++ b/core/src/main/scala/com/nutomic/ensichat/core/body/RouteReply.scala @@ -0,0 +1,50 @@ +package com.nutomic.ensichat.core.body + +import java.nio.ByteBuffer + +import com.nutomic.ensichat.core.util.BufferUtils + +private[core] object RouteReply { + + val Type = 3 + + /** + * Constructs [[RouteReply]] instance from byte array. + */ + def read(array: Array[Byte]): RouteReply = { + val b = ByteBuffer.wrap(array) + val targSeqNum = BufferUtils.getUnsignedShort(b) + val targMetric = BufferUtils.getUnsignedShort(b) + new RouteReply(targSeqNum, targMetric) + } + +} + +/** + * Sends information about a route. + * + * Note that the fields are named different than described in AODVv2. There, targSeqNum and + * targMetric are used to describe the seqNum and metric of the node sending the route reply. In + * Ensichat, we use originSeqNum and originMetric instead, to stay consistent with the header + * fields. That means header.origin, originSeqNum and originMetric all refer to the node sending + * this message. + * + * @param originSeqNum The current sequence number of the node sending this message. + * @param originMetric The metric of the current route to the sending node. + */ +private[core] case class RouteReply(originSeqNum: Int, originMetric: Int) extends MessageBody { + + override def protocolType = RouteReply.Type + + override def contentType = -1 + + override def write: Array[Byte] = { + val b = ByteBuffer.allocate(length) + BufferUtils.putUnsignedShort(b, originSeqNum) + BufferUtils.putUnsignedShort(b, originMetric) + b.array() + } + + override def length = 4 + +} diff --git a/core/src/main/scala/com/nutomic/ensichat/core/body/RouteRequest.scala b/core/src/main/scala/com/nutomic/ensichat/core/body/RouteRequest.scala new file mode 100644 index 0000000..bc728c3 --- /dev/null +++ b/core/src/main/scala/com/nutomic/ensichat/core/body/RouteRequest.scala @@ -0,0 +1,44 @@ +package com.nutomic.ensichat.core.body + +import java.nio.ByteBuffer + +import com.nutomic.ensichat.core.Address +import com.nutomic.ensichat.core.util.BufferUtils + +private[core] object RouteRequest { + + val Type = 2 + + /** + * Constructs [[RouteRequest]] instance from byte array. + */ + def read(array: Array[Byte]): RouteRequest = { + val b = ByteBuffer.wrap(array) + val requested = new Address(BufferUtils.getByteArray(b, Address.Length)) + val origSeqNum = BufferUtils.getUnsignedShort(b) + val originMetric = BufferUtils.getUnsignedShort(b) + val targSeqNum = b.getInt() + new RouteRequest(requested, origSeqNum, targSeqNum, originMetric) + } + +} + +private[core] case class RouteRequest(requested: Address, originSeqNum: Int, targSeqNum: Int, originMetric: Int) + extends MessageBody { + + override def protocolType = RouteRequest.Type + + override def contentType = -1 + + override def write: Array[Byte] = { + val b = ByteBuffer.allocate(length) + b.put(requested.bytes) + BufferUtils.putUnsignedShort(b, originSeqNum) + BufferUtils.putUnsignedShort(b, originMetric) + b.putInt(targSeqNum) + b.array() + } + + override def length = 8 + Address.Length + +} diff --git a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala index aa36f1f..4235964 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala @@ -12,8 +12,10 @@ import com.typesafe.scalalogging.Logger /** * Encapsulates an active connection to another node. */ -class InternetConnectionThread(socket: Socket, crypto: Crypto, onDisconnected: (InternetConnectionThread) => Unit, - onReceive: (Message, InternetConnectionThread) => Unit) extends Thread { +private[core] class InternetConnectionThread(socket: Socket, crypto: Crypto, + onDisconnected: (InternetConnectionThread) => Unit, + onReceive: (Message, InternetConnectionThread) => Unit) + extends Thread { private val logger = Logger(this.getClass) @@ -78,7 +80,6 @@ class InternetConnectionThread(socket: Socket, crypto: Crypto, onDisconnected: ( } catch { case e: IOException => logger.warn("Failed to close socket", e) } - logger.debug("Connection to " + socket.getInetAddress + " closed") onDisconnected(this) } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala index 3df31a9..f68905d 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala @@ -12,9 +12,9 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future import scala.util.Random -object InternetInterface { +private[core] object InternetInterface { - val ServerPort = 26344 + val DefaultPort = 26344 } @@ -23,14 +23,14 @@ object InternetInterface { * * @param maxConnections Maximum number of concurrent connections that should be opened. */ -class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, - settings: SettingsInterface, maxConnections: Int) +private[core] class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, + settings: SettingsInterface, maxConnections: Int, port: Int) extends TransmissionInterface { private val logger = Logger(this.getClass) private lazy val serverThread = - new InternetServerThread(crypto, onConnected, onDisconnected, onReceiveMessage) + new InternetServerThread(crypto, port, onConnected, onDisconnected, onReceiveMessage) private var connections = Set[InternetConnectionThread]() @@ -44,10 +44,8 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, .replace("46.101.249.188:26344", SettingsInterface.DefaultAddresses) settings.put(SettingsInterface.KeyAddresses, servers) - FutureHelper { - serverThread.start() - openAllConnections(maxConnections) - } + serverThread.start() + openAllConnections(maxConnections) } /** @@ -69,13 +67,13 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, .foreach(openConnection) } - private def openConnection(addressPort: String): Unit = { + def openConnection(addressPort: String): Unit = { val (address, port) = if (addressPort.contains(":")) { val split = addressPort.split(":") (split(0), split(1).toInt) } else - (addressPort, InternetInterface.ServerPort) + (addressPort, InternetInterface.DefaultPort) openConnection(address, port) } @@ -100,11 +98,11 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, } private def onDisconnected(connectionThread: InternetConnectionThread): Unit = { - addressDeviceMap.find(_._2 == connectionThread).foreach { ad => - logger.trace("Connection closed to " + ad._1) + getAddressForThread(connectionThread).foreach { ad => + logger.trace("Connection closed to " + ad) connections -= connectionThread - addressDeviceMap -= ad._1 - connectionHandler.onConnectionClosed() + addressDeviceMap -= ad + connectionHandler.onConnectionClosed(ad) } } @@ -122,15 +120,18 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, if (!connectionHandler.onConnectionOpened(msg)) addressDeviceMap -= address case _ => - connectionHandler.onMessageReceived(msg) + connectionHandler.onMessageReceived(msg, getAddressForThread(thread).get) } + private def getAddressForThread(thread: InternetConnectionThread) = + addressDeviceMap.find(_._2 == thread).map(_._1) + /** * Sends the message to nextHop. */ override def send(nextHop: Address, msg: Message): Unit = { addressDeviceMap - .find(_._1 == nextHop) + .filter(_._1 == nextHop || Address.Broadcast == nextHop) .foreach(_._2.send(msg)) } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetServerThread.scala b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetServerThread.scala index cf507b3..56eb1b1 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetServerThread.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetServerThread.scala @@ -6,13 +6,15 @@ import java.net.ServerSocket import com.nutomic.ensichat.core.{Crypto, Message} import com.typesafe.scalalogging.Logger -class InternetServerThread(crypto: Crypto, onConnected: (InternetConnectionThread) => Unit, - onDisconnected: (InternetConnectionThread) => Unit, onReceive: (Message, InternetConnectionThread) => Unit) extends Thread { +class InternetServerThread(crypto: Crypto, port: Int, + onConnected: (InternetConnectionThread) => Unit, + onDisconnected: (InternetConnectionThread) => Unit, + onReceive: (Message, InternetConnectionThread) => Unit) extends Thread { private val logger = Logger(this.getClass) private lazy val socket: Option[ServerSocket] = try { - Option(new ServerSocket(InternetInterface.ServerPort)) + Option(new ServerSocket(port)) } catch { case e: IOException => logger.warn("Failed to create server socket", e) diff --git a/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala b/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala index 937d57e..ec8d464 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala @@ -75,8 +75,9 @@ class Database(path: File, callbackInterface: CallbackInterface) { /** * Inserts the given new message into the database. */ - def onMessageReceived(msg: Message): Unit = { - Await.result(db.run(messages += msg), Duration.Inf) + def onMessageReceived(msg: Message): Unit = msg.body match { + case _: Text => Await.result(db.run(messages += msg), Duration.Inf) + case _ => } def getMessages(address: Address): Seq[Message] = { diff --git a/core/src/main/scala/com/nutomic/ensichat/core/util/LocalRoutesInfo.scala b/core/src/main/scala/com/nutomic/ensichat/core/util/LocalRoutesInfo.scala new file mode 100644 index 0000000..846c8c7 --- /dev/null +++ b/core/src/main/scala/com/nutomic/ensichat/core/util/LocalRoutesInfo.scala @@ -0,0 +1,119 @@ +package com.nutomic.ensichat.core.util + +import com.nutomic.ensichat.core.Address +import com.nutomic.ensichat.core.util.LocalRoutesInfo._ +import com.typesafe.scalalogging.Logger +import org.joda.time.{DateTime, Duration} + +private[core] object LocalRoutesInfo { + + private val ActiveInterval = Duration.standardSeconds(5) + + /** + * [[RouteStates.Idle]]: + * A route that is known, but has not been used in the last [[ActiveInterval. + * [[RouteStates.Active]]: + * A route that is known, and has been used in the last [[ActiveInterval]]. + * [[RouteStates.Invalid]]: + * A route that has been expired or lost, may not be used for forwarding. + * RouteStates.Unconfirmed is not required as connections are always bidirectional. + */ + object RouteStates extends Enumeration { + type RouteStates = Value + val Idle, Active, Invalid = Value + } + +} + +/** + * This class contains information about routes available to this node. + * + * See AODVv2-13 4.5 (Local Route Set), -> implemented + * 6.9 (Local Route Set Maintenance) -> implemented (hopefully correct) + */ +private[core] class LocalRoutesInfo(activeConnections: () => Set[Address]) { + + import RouteStates._ + + private val MaxSeqnumLifetime = Duration.standardSeconds(300) + // TODO: this can probably be much higher because of infrequent topology changes between internet nodes + private val MaxIdleTime = Duration.standardSeconds(300) + + + /** + * Holds information about a local route. + * + * @param destination The destination address that can be reached with this route. + * @param seqNum Sequence number of the last route message that updated this entry. + * @param nextHop The next hop on the path towards destination. + * @param lastUsed The time this route was last used to forward a message. + * @param lastSeqNumUpdate The time seqNum was last updated. + * @param metric The number of hops towards destination using this route. + * @param state The last known state of the route. + */ + case class RouteEntry(destination: Address, seqNum: Int, nextHop: Address, lastUsed: DateTime, + lastSeqNumUpdate: DateTime, metric: Int, state: RouteStates) + + private var routes = Set[RouteEntry]() + + def addRoute(destination: Address, seqNum: Int, nextHop: Address, metric: Int): Unit = { + val entry = RouteEntry(destination, seqNum, nextHop, new DateTime(0), DateTime.now, metric, Idle) + routes += entry + } + + def getRoute(destination: Address): Option[RouteEntry] = { + if (activeConnections().contains(destination)) + return Option(new RouteEntry(destination, 0, destination, DateTime.now, DateTime.now, 1, Idle)) + + handleTimeouts() + val r = routes.toList + .sortWith(_.metric < _.metric) + .find( r => r.destination == destination && r.state != Invalid) + + if (r.isDefined) + routes = routes -- r + r.get.copy(state = Active, lastUsed = DateTime.now) + r + } + + /** + * + * @param address The address which can't be reached any more. + * @return The set of active destinations that can't be reached anymore. + */ + def connectionClosed(address: Address): Set[Address] = { + handleTimeouts() + + val affectedDestinations = + routes + .filter(r => r.state == Active && (r.nextHop == address || r.destination == address)) + .map(_.destination) + + routes = routes.map { r => + if (r.nextHop == address || r.destination == address) + r.copy(state = Invalid) + else + r + } + + affectedDestinations + } + + private def handleTimeouts(): Unit = { + routes = routes + // Delete routes after max lifetime. + .map { r => + if (DateTime.now.isAfter(r.lastSeqNumUpdate.plus(MaxSeqnumLifetime))) + r.copy(seqNum = 0) + else + r + } + // Set routes to invalid after max idle time. + .map { r => + if (DateTime.now.isAfter(r.lastSeqNumUpdate.plus(MaxIdleTime))) + r.copy(state = Invalid) + else + r + } + } + +} \ No newline at end of file diff --git a/core/src/main/scala/com/nutomic/ensichat/core/util/RouteMessageInfo.scala b/core/src/main/scala/com/nutomic/ensichat/core/util/RouteMessageInfo.scala new file mode 100644 index 0000000..953a751 --- /dev/null +++ b/core/src/main/scala/com/nutomic/ensichat/core/util/RouteMessageInfo.scala @@ -0,0 +1,74 @@ +package com.nutomic.ensichat.core.util + +import com.nutomic.ensichat.core.body.{RouteReply, RouteRequest} +import com.nutomic.ensichat.core.{Address, Message, Router} +import org.joda.time.{DateTime, Duration} + +/** + * Contains information about AODVv2 control messages that have been received. + * + * This class handles Route Request and Route Reply messages (referred to as "route messages"). + * + * See AODVv2-13 4.6 (Multicast Route Message Table), -> implemented + * 6.8 (Surpressing Redundant Messages Using the Multicast Route Message Table) -> implemented (hopefully correct) + */ +private[core] class RouteMessageInfo { + + private val MaxSeqnumLifetime = Duration.standardSeconds(300) + + /** + * @param messageType Either [[RouteRequest.Type]] or [[RouteReply.Type]]. + * @param origAddress Source address of the route message triggering the route request. + * @param targAddress Destination address of the route message triggering the route request. + * @param origSeqNum Sequence number associated with the route to [[origAddress]], if route + * message is an RREQ. + * @param targSeqNum Sequence number associated with the route to [[targAddress]], if present in + * the route message. + * @param metric Metric value received in the route message. + * @param timestamp Last time this entry was updated. + */ + private case class RouteMessageEntry(messageType: Int, origAddress: Address, + targAddress: Address, origSeqNum: Int, targSeqNum: Int, + metric: Int, timestamp: DateTime) + + private var entries = Set[RouteMessageEntry]() + + private def addEntry(msg: Message): Unit = msg.body match { + case rreq: RouteRequest => + entries += new RouteMessageEntry(RouteRequest.Type, msg.header.origin, msg.header.target, + msg.header.seqNum, rreq.targSeqNum, rreq.originMetric, + DateTime.now) + case rrep: RouteReply => + entries += new RouteMessageEntry(RouteReply.Type, msg.header.origin, msg.header.target, + msg.header.seqNum, rrep.originSeqNum, rrep.originMetric, + DateTime.now) + } + + def isMessageRedundant(msg: Message): Boolean = { + handleTimeouts() + val existingEntry = + entries.find { e => + val haveEntry = e.messageType == msg.header.protocolType && + e.origAddress == msg.header.origin && e.targAddress == msg.header.target + + val (metric, seqNumComparison) = msg.body match { + case rreq: RouteRequest => (rreq.originMetric, Router.compare(rreq.originSeqNum, e.origSeqNum)) + case rrep: RouteReply => (rrep.originMetric, Router.compare(rrep.originSeqNum, e.targSeqNum)) + } + val isMetricBetter = e.metric < metric + haveEntry && (seqNumComparison > 0 || (seqNumComparison == 0 && isMetricBetter)) + } + if (existingEntry.isDefined) + entries = entries - existingEntry.get + + addEntry(msg) + + existingEntry.isDefined + } + + private def handleTimeouts(): Unit = { + entries = entries.filter { e => + DateTime.now.isBefore(e.timestamp.plus(MaxSeqnumLifetime)) + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala index d08a215..cbabf95 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/CryptoTest.scala @@ -41,8 +41,9 @@ class CryptoTest extends TestCase { def testEncryptDecrypt(): Unit = { MessageTest.messages.foreach{ m => val encrypted = crypto.encryptAndSign(m, Option(crypto.getLocalPublicKey)) - val decrypted = crypto.verifyAndDecrypt(encrypted, Option(crypto.getLocalPublicKey)) - assertEquals(m.body, decrypted.get.body) + assertTrue(crypto.verify(encrypted, Option(crypto.getLocalPublicKey))) + val decrypted = crypto.decrypt(encrypted) + assertEquals(m.body, decrypted.body) assertEquals(m.header, encrypted.header) } } diff --git a/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala index 881d42f..2af4609 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala @@ -58,9 +58,10 @@ class MessageTest extends TestCase { val read = Message.read(new ByteArrayInputStream(bytes)) assertEquals(encrypted.crypto, read.crypto) - val decrypted = crypto.verifyAndDecrypt(read, Option(crypto.getLocalPublicKey)) - assertEquals(m.header, decrypted.get.header) - assertEquals(m.body, decrypted.get.body) + assertTrue(crypto.verify(read, Option(crypto.getLocalPublicKey))) + val decrypted = crypto.decrypt(read) + assertEquals(m.header, decrypted.header) + assertEquals(m.body, decrypted.body) } } diff --git a/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala index 0441212..88ca751 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala @@ -1,34 +1,45 @@ package com.nutomic.ensichat.core +import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.{Date, GregorianCalendar} import com.nutomic.ensichat.core.body.{Text, UserInfo} import com.nutomic.ensichat.core.header.ContentHeader +import com.nutomic.ensichat.core.util.LocalRoutesInfo import junit.framework.TestCase import org.junit.Assert._ class RouterTest extends TestCase { - private def neighbors() = Set[Address](AddressTest.a1, AddressTest.a2, AddressTest.a3) + private def neighbors() = Set[Address](AddressTest.a1, AddressTest.a2, AddressTest.a4) - private val msg = generateMessage(AddressTest.a1, AddressTest.a4, 1) + def testNoRouteFound(): Unit = { + val msg = generateMessage(AddressTest.a2, AddressTest.a3, 1) + val latch = new CountDownLatch(1) + val router = new Router(new LocalRoutesInfo(neighbors), + (_, _) => fail("Message shouldn't be forwarded"), m => { + assertEquals(msg, m) + latch.countDown() + }) + router.forwardMessage(msg) + assertTrue(latch.await(1, TimeUnit.SECONDS)) + } - /** - * Messages should be sent to all neighbors. - */ - def testFlooding(): Unit = { + def testNextHop(): Unit = { + val msg = generateMessage(AddressTest.a1, AddressTest.a4, 1) var sentTo = Set[Address]() - val router: Router = new Router(neighbors, + val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => { sentTo += a - }) + }, _ => ()) router.forwardMessage(msg) - assertEquals(neighbors(), sentTo) + assertEquals(Set(AddressTest.a4), sentTo) } def testMessageSame(): Unit = { - val router: Router = new Router(neighbors, + val msg = generateMessage(AddressTest.a1, AddressTest.a4, 1) + val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => { assertEquals(msg.header.origin, m.header.origin) assertEquals(msg.header.target, m.header.target) @@ -38,7 +49,7 @@ class RouterTest extends TestCase { assertEquals(msg.header.hopLimit, m.header.hopLimit) assertEquals(msg.body, m.body) assertEquals(msg.crypto, m.crypto) - }) + }, _ => ()) router.forwardMessage(msg) } @@ -47,26 +58,32 @@ class RouterTest extends TestCase { */ def testDifferentSenders(): Unit = { var sentTo = Set[Address]() - val router: Router = new Router(neighbors, (a, m) => sentTo += a) + val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => sentTo += a, _ => ()) - router.forwardMessage(msg) - assertEquals(neighbors(), sentTo) + router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, 1)) + assertEquals(Set(AddressTest.a4), sentTo) sentTo = Set[Address]() router.forwardMessage(generateMessage(AddressTest.a2, AddressTest.a4, 1)) - assertEquals(neighbors(), sentTo) + assertEquals(Set(AddressTest.a4), sentTo) + } + + def testSeqNumComparison(): Unit = { + Router.compare(1, ContentHeader.SeqNumRange.last) + Router.compare(ContentHeader.SeqNumRange.last / 2, ContentHeader.SeqNumRange.last) + Router.compare(ContentHeader.SeqNumRange.last / 2, 1) } def testDiscardOldIgnores(): Unit = { def test(first: Int, second: Int) { var sentTo = Set[Address]() - val router: Router = new Router(neighbors, (a, m) => sentTo += a) - router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a3, first)) - router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a3, second)) + val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => sentTo += a, _ => ()) + router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, first)) + router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, second)) sentTo = Set[Address]() - router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a3, first)) - assertEquals(neighbors(), sentTo) + router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, first)) + assertEquals(Set(AddressTest.a4), sentTo) } test(1, ContentHeader.SeqNumRange.last) @@ -77,7 +94,7 @@ class RouterTest extends TestCase { def testHopLimit(): Unit = Range(19, 22).foreach { i => val msg = new Message( new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), i), new Text("")) - val router: Router = new Router(neighbors, (a, m) => fail()) + val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => fail(), _ => ()) router.forwardMessage(msg) } diff --git a/core/src/test/scala/com/nutomic/ensichat/core/body/RouteErrorTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/body/RouteErrorTest.scala new file mode 100644 index 0000000..51b2002 --- /dev/null +++ b/core/src/test/scala/com/nutomic/ensichat/core/body/RouteErrorTest.scala @@ -0,0 +1,16 @@ +package com.nutomic.ensichat.core.body + +import com.nutomic.ensichat.core.AddressTest +import junit.framework.TestCase +import org.junit.Assert._ + +class RouteErrorTest extends TestCase { + + def testWriteRead(): Unit = { + val rerr = new RouteError(AddressTest.a2, 62000) + val bytes = rerr.write + val parsed = RouteError.read(bytes) + assertEquals(rerr, parsed) + } + +} diff --git a/core/src/test/scala/com/nutomic/ensichat/core/body/RouteReplyTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/body/RouteReplyTest.scala new file mode 100644 index 0000000..06149d0 --- /dev/null +++ b/core/src/test/scala/com/nutomic/ensichat/core/body/RouteReplyTest.scala @@ -0,0 +1,16 @@ +package com.nutomic.ensichat.core.body + +import com.nutomic.ensichat.core.AddressTest +import junit.framework.TestCase +import org.junit.Assert._ + +class RouteReplyTest extends TestCase { + + def testWriteRead(): Unit = { + val rrep = new RouteReply(61000, 123) + val bytes = rrep.write + val parsed = RouteReply.read(bytes) + assertEquals(rrep, parsed) + } + +} diff --git a/core/src/test/scala/com/nutomic/ensichat/core/body/RouteRequestTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/body/RouteRequestTest.scala new file mode 100644 index 0000000..b726c2b --- /dev/null +++ b/core/src/test/scala/com/nutomic/ensichat/core/body/RouteRequestTest.scala @@ -0,0 +1,16 @@ +package com.nutomic.ensichat.core.body + +import com.nutomic.ensichat.core.AddressTest +import junit.framework.TestCase +import org.junit.Assert._ + +class RouteRequestTest extends TestCase { + + def testWriteRead(): Unit = { + val rreq = new RouteRequest(AddressTest.a2, 60000, 60001, 60002) + val bytes = rreq.write + val parsed = RouteRequest.read(bytes) + assertEquals(rreq, parsed) + } + +} diff --git a/core/src/test/scala/com/nutomic/ensichat/core/util/LocalRoutesInfoTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/util/LocalRoutesInfoTest.scala new file mode 100644 index 0000000..f5ebba6 --- /dev/null +++ b/core/src/test/scala/com/nutomic/ensichat/core/util/LocalRoutesInfoTest.scala @@ -0,0 +1,45 @@ +package com.nutomic.ensichat.core.util + +import com.nutomic.ensichat.core.AddressTest +import junit.framework.TestCase +import org.joda.time.{DateTime, DateTimeUtils, Duration} +import org.junit.Assert._ + +class LocalRoutesInfoTest extends TestCase { + + private def connections() = Set(AddressTest.a1, AddressTest.a2) + + def testRoute(): Unit = { + val routesInfo = new LocalRoutesInfo(connections) + routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1) + val route = routesInfo.getRoute(AddressTest.a3) + assertEquals(AddressTest.a1, route.get.nextHop) + } + + def testBestMetric(): Unit = { + val routesInfo = new LocalRoutesInfo(connections) + routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1) + routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a2, 2) + val route = routesInfo.getRoute(AddressTest.a3) + assertEquals(AddressTest.a1, route.get.nextHop) + } + + def testConnectionClosed(): Unit = { + val routesInfo = new LocalRoutesInfo(connections) + routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1) + routesInfo.addRoute(AddressTest.a4, 0, AddressTest.a1, 1) + // Mark the route as active, because only active routes are returned. + routesInfo.getRoute(AddressTest.a3) + val unreachable = routesInfo.connectionClosed(AddressTest.a1) + assertEquals(Set(AddressTest.a3), unreachable) + } + + def testTimeout(): Unit = { + DateTimeUtils.setCurrentMillisFixed(new DateTime().getMillis) + val routesInfo = new LocalRoutesInfo(connections) + routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1) + DateTimeUtils.setCurrentMillisFixed(DateTime.now.plus(Duration.standardSeconds(400)).getMillis) + assertEquals(None, routesInfo.getRoute(AddressTest.a3)) + } + +} \ No newline at end of file diff --git a/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala new file mode 100644 index 0000000..bc6e2b5 --- /dev/null +++ b/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala @@ -0,0 +1,79 @@ +package com.nutomic.ensichat.core.util + +import com.nutomic.ensichat.core.body.{RouteReply, RouteRequest} +import com.nutomic.ensichat.core.header.MessageHeader +import com.nutomic.ensichat.core.{AddressTest, Message} +import junit.framework.TestCase +import org.joda.time.{DateTime, DateTimeUtils, Duration} +import org.junit.Assert._ + +class RouteMessageInfoTest extends TestCase { + + /** + * Test case in which we have an entry with the same type, origin and target. + */ + def testSameMessage(): Unit = { + val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val msg = new Message(header, new RouteRequest(AddressTest.a3, 2, 3, 1)) + val rmi = new RouteMessageInfo() + assertFalse(rmi.isMessageRedundant(msg)) + assertTrue(rmi.isMessageRedundant(msg)) + } + + /** + * Forward a message with a seqnum that is older than the latest. + */ + def testSeqNumOlder(): Unit = { + val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 0, 0, 0)) + val rmi = new RouteMessageInfo() + assertFalse(rmi.isMessageRedundant(msg1)) + + val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 3) + val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 2, 0, 0)) + assertTrue(rmi.isMessageRedundant(msg2)) + } + + /** + * Announce a route with a metric that is worse than the existing one. + */ + def testMetricWorse(): Unit = { + val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 1, 0, 2)) + val rmi = new RouteMessageInfo() + assertFalse(rmi.isMessageRedundant(msg1)) + + val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2) + val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 1, 0, 4)) + assertTrue(rmi.isMessageRedundant(msg2)) + } + + /** + * Announce route with a better metric. + */ + def testMetricBetter(): Unit = { + val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val msg1 = new Message(header1, new RouteReply(0, 4)) + val rmi = new RouteMessageInfo() + assertFalse(rmi.isMessageRedundant(msg1)) + + val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2) + val msg2 = new Message(header2, new RouteReply(0, 2)) + assertFalse(rmi.isMessageRedundant(msg2)) + } + + /** + * Test that entries are removed after [[RouteMessageInfo.MaxSeqnumLifetime]]. + */ + def testTimeout(): Unit = { + val rmi = new RouteMessageInfo() + DateTimeUtils.setCurrentMillisFixed(DateTime.now.getMillis) + val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val msg = new Message(header, new RouteRequest(AddressTest.a3, 0, 0, 0)) + assertFalse(rmi.isMessageRedundant(msg)) + + DateTimeUtils.setCurrentMillisFixed(DateTime.now.plus(Duration.standardSeconds(400)).getMillis) + assertFalse(rmi.isMessageRedundant(msg)) + } + +} \ No newline at end of file diff --git a/integration/.gitignore b/integration/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/integration/.gitignore @@ -0,0 +1 @@ +/build diff --git a/integration/build.gradle b/integration/build.gradle new file mode 100644 index 0000000..6736632 --- /dev/null +++ b/integration/build.gradle @@ -0,0 +1,12 @@ +apply plugin: 'scala' +apply plugin: 'application' + +dependencies { + compile 'org.scala-lang:scala-library:2.11.7' + compile 'com.github.scala-incubator.io:scala-io-file_2.11:0.4.3' + compile project(path: ':core') +} + +mainClassName = 'com.nutomic.ensichat.integration.Main' +version = "0.2.3" +applicationName = 'ensichat-server' diff --git a/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala b/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala new file mode 100644 index 0000000..58c58af --- /dev/null +++ b/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala @@ -0,0 +1,84 @@ +package com.nutomic.ensichat.integration + +import java.io.File +import java.util.concurrent.{LinkedBlockingDeque, LinkedBlockingQueue} + +import com.nutomic.ensichat.core.body.{RouteError, RouteRequest, RouteReply} +import com.nutomic.ensichat.core.interfaces.{CallbackInterface, SettingsInterface} +import com.nutomic.ensichat.core.util.Database +import com.nutomic.ensichat.core.{ConnectionHandler, Crypto, Message} +import com.nutomic.ensichat.integration.LocalNode._ + +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scalax.file.Path + +object LocalNode { + + private final val StartingPort = 21000 + + object EventType extends Enumeration { + type EventType = Value + val MessageReceived, ConnectionsChanged, ContactsUpdated = Value + } + + class FifoStream[A]() { + private val queue = new LinkedBlockingQueue[Option[A]]() + def toStream: Stream[A] = queue.take match { + case Some(a) => Stream.cons(a, toStream) + case None => Stream.empty + } + def close() = queue add None + def enqueue(a: A) = queue.put(Option(a)) + } + +} + +/** + * Runs an ensichat node on localhost. + * + * Received messages can be accessed through [[eventQueue]]. + * + * @param index Number of this node. The server port is opened on port [[StartingPort]] + index. + * @param configFolder Folder where keys and configuration should be stored. + */ +class LocalNode(val index: Int, configFolder: File) extends CallbackInterface { + + import com.nutomic.ensichat.integration.LocalNode.EventType._ + private val databaseFile = new File(configFolder, "database") + private val keyFolder = new File(configFolder, "keys") + + private val database = new Database(databaseFile, this) + private val settings = new SettingsInterface { + private var values = Map[String, Any]() + override def get[T](key: String, default: T): T = values.get(key).map(_.asInstanceOf[T]).getOrElse(default) + override def put[T](key: String, value: T): Unit = values += (key -> value.asInstanceOf[Any]) + } + + val crypto = new Crypto(settings, keyFolder) + val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 0, port) + val eventQueue = new FifoStream[(EventType.EventType, Option[Message])]() + + configFolder.mkdirs() + keyFolder.mkdirs() + settings.put(SettingsInterface.KeyAddresses, "") + Await.result(connectionHandler.start(), Duration.Inf) + + def port = StartingPort + index + + def stop(): Unit = { + connectionHandler.stop() + Path(configFolder).deleteRecursively() + } + + def onMessageReceived(msg: Message): Unit = { + eventQueue.enqueue((EventType.MessageReceived, Option(msg))) + } + + def onConnectionsChanged(): Unit = + eventQueue.enqueue((EventType.ConnectionsChanged, None)) + + def onContactsUpdated(): Unit = + eventQueue.enqueue((EventType.ContactsUpdated, None)) + +} diff --git a/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala b/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala new file mode 100644 index 0000000..65ccc90 --- /dev/null +++ b/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala @@ -0,0 +1,135 @@ +package com.nutomic.ensichat.integration + +import java.io.File +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import com.nutomic.ensichat.core.Crypto +import com.nutomic.ensichat.core.body.Text + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, Future} +import scala.util.Try + +/** + * Creates some local nodes, connects them and sends messages between them. + * + * If the test runs slow or fails, changing [[Crypto.PublicKeySize]] to 512 should help. + */ +object Main extends App { + + val nodes = createMesh() + System.out.println("\n\nAll nodes connected!\n\n") + + sendMessages(nodes) + System.out.println("\n\nAll messages sent!\n\n") + + // Stop node 1, forcing route errors and messages to use the (longer) path via nodes 7 and 8. + nodes(1).connectionHandler.stop() + System.out.println("node 1 stopped") + sendMessages(nodes) + + /** + * Creates a new mesh with a predefined layout. + * + * Graphical representation: + * 8 —— 7 + * / \ + * 0———1———3———4 + * \ / | | + * 2 5———6 + * + * @return List of [[LocalNode]]s, ordered from 0 to 7. + */ + private def createMesh(): Seq[LocalNode] = { + val nodes = Await.result(Future.sequence(0.to(8).map(createNode)), Duration.Inf) + sys.addShutdownHook(nodes.foreach(_.stop())) + + connectNodes(nodes(0), nodes(1)) + connectNodes(nodes(0), nodes(2)) + connectNodes(nodes(1), nodes(2)) + connectNodes(nodes(1), nodes(3)) + connectNodes(nodes(3), nodes(4)) + connectNodes(nodes(3), nodes(5)) + connectNodes(nodes(4), nodes(6)) + connectNodes(nodes(5), nodes(6)) + connectNodes(nodes(3), nodes(7)) + connectNodes(nodes(0), nodes(8)) + connectNodes(nodes(7), nodes(8)) + nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}")) + + nodes + } + + private def createNode(index: Int): Future[LocalNode] = { + val configFolder = new File(s"build/node$index/") + assert(!configFolder.exists(), s"stale config exists in $configFolder") + Future(new LocalNode(index, configFolder)) + } + + private def connectNodes(first: LocalNode, second: LocalNode): Unit = { + first.connectionHandler.connect(s"localhost:${second.port}") + + first.eventQueue.toStream.find(_._1 == LocalNode.EventType.ConnectionsChanged) + second.eventQueue.toStream.find(_._1 == LocalNode.EventType.ConnectionsChanged) + + val firstAddress = first.crypto.localAddress + val secondAddress = second.crypto.localAddress + val firstConnections = first.connectionHandler.connections() + val secondConnections = second.connectionHandler.connections() + + assert(firstConnections.contains(secondAddress), + s"${first.index} is not connected to ${second.index}") + assert(secondConnections.contains(firstAddress), + s"${second.index} is not connected to ${second.index}") + + System.out.println(s"${first.index} and ${second.index} connected") + } + + private def sendMessages(nodes: Seq[LocalNode]): Unit = { + sendMessage(nodes(0), nodes(2)) + sendMessage(nodes(2), nodes(0)) + sendMessage(nodes(4), nodes(3)) + sendMessage(nodes(3), nodes(5)) + sendMessage(nodes(4), nodes(6)) + sendMessage(nodes(2), nodes(3)) + sendMessage(nodes(0), nodes(3)) + sendMessage(nodes(3), nodes(6)) + sendMessage(nodes(3), nodes(2)) + } + + + private def sendMessage(from: LocalNode, to: LocalNode): Unit = { + addKey(to.crypto, from.crypto) + addKey(from.crypto, to.crypto) + + System.out.println(s"sendMessage(${from.index}, ${to.index})") + val text = s"${from.index} to ${to.index}" + from.connectionHandler.sendTo(to.crypto.localAddress, new Text(text)) + + val latch = new CountDownLatch(1) + Future { + val exists = + to.eventQueue.toStream.exists { event => + if (event._1 != LocalNode.EventType.MessageReceived) + false + else { + event._2.get.body match { + case t: Text => t.text == text + case _ => false + } + } + } + assert(exists, s"message from ${from.index} did not arrive at ${to.index}") + latch.countDown() + } + assert(latch.await(1000, TimeUnit.MILLISECONDS)) + } + + private def addKey(addTo: Crypto, addFrom: Crypto): Unit = { + if (Try(addTo.getPublicKey(addFrom.localAddress)).isFailure) + addTo.addPublicKey(addFrom.localAddress, addFrom.getLocalPublicKey) + + } + +} \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 65899c0..e61120e 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1 +1 @@ -include ':android', ':core', ':server' +include ':android', ':core', ':server', ':integration'