Implemented AODVv2, including integration test (fixes #33).

For documentation on how AODVv2 works, see this link:
https://datatracker.ietf.org/doc/draft-ietf-manet-aodvv2/

Note that this implementation is incompatible with AODVv2 itself,
as various details are changed, and not all features have been
implemented
This commit is contained in:
Felix Ableitner 2016-04-07 13:54:36 +02:00
parent 2cc4928a99
commit 83fc696cc7
35 changed files with 1208 additions and 137 deletions

View file

@ -26,6 +26,9 @@ Nodes MUST NOT have a public key with the broadcast address or null
address as hash. Additionally, nodes MUST NOT connect to a node with address as hash. Additionally, nodes MUST NOT connect to a node with
either address. either address.
All integer fields are in network byte order, and unsigned (unless
specified otherwise).
Crypto Crypto
------ ------
@ -40,22 +43,15 @@ private key, and the result written to the 'Encryption Data' part.
Routing Routing
------- -------
A simple flood routing protocol is currently used. Every node forwards The routing protocol is based on
all messages, unless a message with the same Origin and Sequence Number [AODVv2](https://datatracker.ietf.org/doc/draft-ietf-manet-aodvv2/),
has already been received. with various features left out.
Nodes MUST store pairs of (Origin, Sequence Number) for all received TODO: Add Documentation for routing protocol.
messages. After receiving a new message, entries with the same Origin
and Sequence Number between _received_ + 1 and _received_ + 32767 MUST
be removed (with a wrap around at the maximum value). The entries MUST
NOT be cleared while the program is running. They MAY be cleared when
the program is exited.
There is currently no support for offline messages. If sender and There is currently no support for offline messages. If sender and
receiver are not in the same mesh, the message will not arrive. receiver are not in the same mesh, the message will not arrive.
Nodes are free implement different routing algorithms.
Messages Messages
-------- --------
@ -84,9 +80,7 @@ AES key is wrapped with the recipient's public RSA key.
### Header ### Header
Every message starts with one 74 byte header indicating the message Every message starts with one 74 byte header indicating the message
version, type and ID, followed by the length of the message. The version, type and ID, followed by the length of the message.
header is in network byte order, i.e. big endian. The header may have
6 bytes of additional data.
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
@ -220,6 +214,74 @@ After this message has been received, communication with normal messages
may start. may start.
### Route Request (Protocol-Type = 2)
Sent to request a route to a specific Target Address.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Address (32 bytes) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| OrigSeqNum | OriginMetric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| TargMetric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Equivalent to the Sequence Number in the message header.
Set OrigMetric = RouterClient.Cost for the Router Client entry
which includes OrigAddr.
If an Invalid route exists in the Local Route Set matching
TargAddr using longest prefix matching and has a valid
sequence number, set TargSeqNum = LocalRoute.SeqNum.
Otherwise, set TargSeqNum = -1. This field is signed.
### Route Reply (Protocol-Type = 3)
Sent as a reply when a Route Request arrives, to inform other nodes
about a route.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| TargSeqNum | TargMetric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Set TargMetric = RouterClient.Cost for the Router Client entry
which includes TargAddr.
### Route Error (Protocol-Type = 4)
Notifies other nodes of routes that are no longer available. The target
address MUST be set to the broadcast address.
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Packet Source (32 bytes) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Address (32 bytes) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| SeqNum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Packet Source is the source address of the message triggering this
Route Error. If the route error is not triggered by a message,
this MUST be set to the null address.
Address is the address that is no longer reachable.
SeqNum is the sequence number of the route that is no longer available
(if known). Otherwise, set TargSeqNum = -1. This field is signed.
Content Messages Content Messages
---------------- ----------------

View file

@ -18,10 +18,20 @@ To setup a development environment, just install [Android Studio](https://develo
and import the project. and import the project.
Alternatively, you can use the command line. To create a debug apk, run `./gradlew assembleDevDebug`. Alternatively, you can use the command line. To create a debug apk, run `./gradlew assembleDevDebug`.
This requires at least Android Lollipop on your development device. If you don't have Lollipop, you This requires at least Android Lollipop on your development device. If you don't have 5.0 or higher,
can alternatively use `./gradlew assembleRelDebug`. However, this results in considerably slower you have to use `./gradlew assembleRelDebug`. However, this results in considerably slower
incremental builds. To create a release apk, run `./gradlew assembleRelRelease`. incremental builds. To create a release apk, run `./gradlew assembleRelRelease`.
Testing
-------
You can run the unit tests with `./gradlew test`. After connecting an Android device, you can run
the Android tests with `./gradlew connectedDevDebugAndroidTest` (or
`./gradlew connectedRelDebugAndroidTest` if your Android version is lower than 5.0).
To run integration tests for the core module, use `./gradlew integration:run`. If this fails (or
is very slow), try changing the value of Crypto#PublicKeySize to 512 (in the core module).
License License
------- -------

View file

@ -11,7 +11,7 @@ buildscript {
} }
dependencies { dependencies {
compile 'com.android.support:design:23.1.1' compile 'com.android.support:design:23.4.0'
compile 'com.android.support:multidex:1.0.1' compile 'com.android.support:multidex:1.0.1'
compile 'org.scala-lang:scala-library:2.11.7' compile 'org.scala-lang:scala-library:2.11.7'
compile 'com.mobsandgeeks:adapter-kit:0.5.3' compile 'com.mobsandgeeks:adapter-kit:0.5.3'
@ -46,9 +46,18 @@ android {
testInstrumentationRunner "com.android.test.runner.MultiDexTestRunner" testInstrumentationRunner "com.android.test.runner.MultiDexTestRunner"
} }
buildTypes.debug { buildTypes {
applicationIdSuffix ".debug" debug {
testCoverageEnabled true applicationIdSuffix ".debug"
testCoverageEnabled true
}
release {
// HACK: This shouldn't be needed, but multidex isn't working correctly.
// https://code.google.com/p/android/issues/detail?id=206131
// https://code.google.com/p/android/issues/detail?id=209084
minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android.txt'), file('proguard-rules.txt')
}
} }
// Increasing minSdkVersion reduces compilation time for MultiDex. // Increasing minSdkVersion reduces compilation time for MultiDex.

View file

@ -0,0 +1,36 @@
# Add project specific ProGuard rules here.
# By default, the flags in this file are appended to flags specified
# in /home/sg/adt/sdk/tools/proguard/proguard-android.txt
# You can edit the include path and order by changing the ProGuard
# include property in project.properties.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# Add any project specific keep options here:
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
-dontobfuscate
-keep class scala.** { *; }
-keep class slick.** { *; }
-keep class org.joda.time.** { *; }
-keep class org.h2.** { *; }
-keep class java.util.** { *; }
-keepclassmembers class java.util.Comparator {
public *;
}
-dontwarn scala.**
-dontwarn slick.**
-dontwarn org.joda.time.**
-dontwarn org.h2.**
-dontwarn java.util.function.**
-ignorewarnings

View file

@ -170,10 +170,11 @@ class BluetoothInterface(context: Context, mainHandler: Handler,
* Removes device from active connections. * Removes device from active connections.
*/ */
def onConnectionClosed(device: Device, socket: BluetoothSocket): Unit = { def onConnectionClosed(device: Device, socket: BluetoothSocket): Unit = {
val address = getAddressForDevice(device.id)
devices -= device.id devices -= device.id
connections -= device.id connections -= device.id
connectionHandler.onConnectionClosed()
addressDeviceMap = addressDeviceMap.filterNot(_._2 == device.id) addressDeviceMap = addressDeviceMap.filterNot(_._2 == device.id)
connectionHandler.onConnectionClosed(address)
} }
/** /**
@ -192,15 +193,18 @@ class BluetoothInterface(context: Context, mainHandler: Handler,
if (!connectionHandler.onConnectionOpened(msg)) if (!connectionHandler.onConnectionOpened(msg))
addressDeviceMap -= address addressDeviceMap -= address
case _ => case _ =>
connectionHandler.onMessageReceived(msg) connectionHandler.onMessageReceived(msg, getAddressForDevice(device))
} }
private def getAddressForDevice(device: Device.ID) =
addressDeviceMap.find(_._2 == device).get._1
/** /**
* Sends the message to nextHop. * Sends the message to nextHop.
*/ */
override def send(nextHop: Address, msg: Message): Unit = { override def send(nextHop: Address, msg: Message): Unit = {
addressDeviceMap addressDeviceMap
.find(_._1 == nextHop) .find(_._1 == nextHop || Address.Broadcast == nextHop)
.map(i => connections.get(i._2)) .map(i => connections.get(i._2))
.getOrElse(None) .getOrElse(None)
.foreach(_.send(msg)) .foreach(_.send(msg))
@ -210,11 +214,6 @@ class BluetoothInterface(context: Context, mainHandler: Handler,
* Returns all active Bluetooth connections. * Returns all active Bluetooth connections.
*/ */
override def getConnections: Set[Address] = override def getConnections: Set[Address] =
connections.flatMap { x => connections.map( c => getAddressForDevice(c._1)).toSet
addressDeviceMap
.find(_._2 == x._1)
.map(_._1)
}
.toSet
} }

View file

@ -5,6 +5,7 @@ dependencies {
compile 'com.h2database:h2:1.4.191' compile 'com.h2database:h2:1.4.191'
compile 'com.typesafe.slick:slick_2.11:3.1.1' compile 'com.typesafe.slick:slick_2.11:3.1.1'
compile 'com.typesafe.scala-logging:scala-logging_2.11:3.4.0' compile 'com.typesafe.scala-logging:scala-logging_2.11:3.4.0'
compile 'joda-time:joda-time:2.9.3'
testCompile 'junit:junit:4.12' testCompile 'junit:junit:4.12'
} }

View file

@ -3,7 +3,7 @@
<appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender"> <appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
<target>System.out</target> <target>System.out</target>
<encoder> <encoder>
<pattern>%d{HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n</pattern> <pattern>%d{HH:mm:ss} %level/%logger{0}: %msg%n</pattern>
</encoder> </encoder>
</appender> </appender>

View file

@ -56,4 +56,9 @@ final case class Address(bytes: Array[Byte]) {
.grouped(Address.GroupLength) .grouped(Address.GroupLength)
.reduce(_ + "-" + _) .reduce(_ + "-" + _)
/**
* Returns shortened address, useful for debugging.
*/
def short = toString.split("-").head
} }

View file

@ -1,15 +1,18 @@
package com.nutomic.ensichat.core package com.nutomic.ensichat.core
import java.security.InvalidKeyException
import java.util.Date import java.util.Date
import com.nutomic.ensichat.core.body.{ConnectionInfo, MessageBody, UserInfo} import com.nutomic.ensichat.core.body._
import com.nutomic.ensichat.core.header.ContentHeader import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader}
import com.nutomic.ensichat.core.interfaces._ import com.nutomic.ensichat.core.interfaces._
import com.nutomic.ensichat.core.internet.InternetInterface import com.nutomic.ensichat.core.internet.InternetInterface
import com.nutomic.ensichat.core.util.{Database, FutureHelper} import com.nutomic.ensichat.core.util.{Database, FutureHelper, LocalRoutesInfo, RouteMessageInfo}
import com.typesafe.scalalogging.Logger import com.typesafe.scalalogging.Logger
import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._
/** /**
* High-level handling of all message transfers and callbacks. * High-level handling of all message transfers and callbacks.
@ -19,16 +22,33 @@ import scala.concurrent.ExecutionContext.Implicits.global
*/ */
final class ConnectionHandler(settings: SettingsInterface, database: Database, final class ConnectionHandler(settings: SettingsInterface, database: Database,
callbacks: CallbackInterface, crypto: Crypto, callbacks: CallbackInterface, crypto: Crypto,
maxInternetConnections: Int) { maxInternetConnections: Int,
port: Int = InternetInterface.DefaultPort) {
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
private val MissingRouteMessageTimeout = 5.minutes
private var transmissionInterfaces = Set[TransmissionInterface]() private var transmissionInterfaces = Set[TransmissionInterface]()
private lazy val router = new Router(connections, sendVia)
private lazy val seqNumGenerator = new SeqNumGenerator(settings) private lazy val seqNumGenerator = new SeqNumGenerator(settings)
private val localRoutesInfo = new LocalRoutesInfo(connections)
private val routeMessageInfo = new RouteMessageInfo()
private lazy val router = new Router(localRoutesInfo,
(a, m) => transmissionInterfaces.foreach(_.send(a, m)),
noRouteFound)
/**
* Contains messages that couldn't be forwarded because we don't know a route.
*
* These will be buffered until we receive a [[RouteReply]] for the target, or when until the
* message has couldn't be forwarded after [[MissingRouteMessageTimeout]].
*/
private var missingRouteMessages = Set[(Message, Date)]()
/** /**
* Holds all known users. * Holds all known users.
* *
@ -42,14 +62,15 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
* @param additionalInterfaces Instances of [[TransmissionInterface]] to transfer data over * @param additionalInterfaces Instances of [[TransmissionInterface]] to transfer data over
* platform specific interfaces (eg Bluetooth). * platform specific interfaces (eg Bluetooth).
*/ */
def start(additionalInterfaces: Set[TransmissionInterface] = Set()): Unit = { def start(additionalInterfaces: Set[TransmissionInterface] = Set()): Future[Unit] = {
additionalInterfaces.foreach(transmissionInterfaces += _) additionalInterfaces.foreach(transmissionInterfaces += _)
FutureHelper { FutureHelper {
crypto.generateLocalKeys() crypto.generateLocalKeys()
logger.info("Service started, address is " + crypto.localAddress) logger.info("Service started, address is " + crypto.localAddress)
logger.info("Local user is " + settings.get(SettingsInterface.KeyUserName, "none") + logger.info("Local user is " + settings.get(SettingsInterface.KeyUserName, "none") +
" with status '" + settings.get(SettingsInterface.KeyUserStatus, "") + "'") " with status '" + settings.get(SettingsInterface.KeyUserStatus, "") + "'")
transmissionInterfaces += new InternetInterface(this, crypto, settings, maxInternetConnections) transmissionInterfaces +=
new InternetInterface(this, crypto, settings, maxInternetConnections, port)
transmissionInterfaces.foreach(_.create()) transmissionInterfaces.foreach(_.create())
} }
} }
@ -63,6 +84,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
* Sends a new message to the given target address. * Sends a new message to the given target address.
*/ */
def sendTo(target: Address, body: MessageBody): Unit = { def sendTo(target: Address, body: MessageBody): Unit = {
assert(body.contentType != -1)
FutureHelper { FutureHelper {
val messageId = settings.get("message_id", 0L) val messageId = settings.get("message_id", 0L)
val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(), val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(),
@ -76,23 +98,165 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
} }
} }
private def sendVia(nextHop: Address, msg: Message) = private def requestRoute(target: Address): Unit = {
transmissionInterfaces.foreach(_.send(nextHop, msg)) assert(localRoutesInfo.getRoute(target).isEmpty)
val seqNum = seqNumGenerator.next()
val targetSeqNum = localRoutesInfo.getRoute(target).map(_.seqNum).getOrElse(-1)
val body = new RouteRequest(target, seqNum, targetSeqNum, 0)
val header = new MessageHeader(body.protocolType, crypto.localAddress, Address.Broadcast, seqNum)
val signed = crypto.sign(new Message(header, body))
router.forwardMessage(signed)
}
private def replyRoute(target: Address, replyTo: Address): Unit = {
val seqNum = seqNumGenerator.next()
val body = new RouteReply(seqNum, 0)
val header = new MessageHeader(body.protocolType, crypto.localAddress, replyTo, seqNum)
val signed = crypto.sign(new Message(header, body))
router.forwardMessage(signed)
}
private def routeError(address: Address, packetSource: Option[Address]): Unit = {
val destination = packetSource.getOrElse(Address.Broadcast)
val header = new MessageHeader(RouteError.Type, crypto.localAddress, destination,
seqNumGenerator.next())
val seqNum = localRoutesInfo.getRoute(address).map(_.seqNum).getOrElse(-1)
val body = new RouteError(address, seqNum)
val signed = crypto.sign(new Message(header, body))
router.forwardMessage(signed)
}
/**
* Force connect to a sepcific internet.
*
* @param address An address in the format IP;port or hostname:port.
*/
def connect(address: String): Unit = {
transmissionInterfaces
.find(_.isInstanceOf[InternetInterface])
.map(_.asInstanceOf[InternetInterface])
.foreach(_.openConnection(address))
}
/** /**
* Decrypts and verifies incoming messages, forwards valid ones to [[onNewMessage()]]. * Decrypts and verifies incoming messages, forwards valid ones to [[onNewMessage()]].
*/ */
def onMessageReceived(msg: Message): Unit = { def onMessageReceived(msg: Message, previousHop: Address): Unit = {
if (router.isMessageSeen(msg)) { if (router.isMessageSeen(msg)) {
logger.trace("Ignoring message from " + msg.header.origin + " that we already received") logger.trace("Ignoring message from " + msg.header.origin + " that we already received")
} else if (msg.header.target == crypto.localAddress) { return
crypto.verifyAndDecrypt(msg) match {
case Some(m) => onNewMessage(m)
case None => logger.info("Ignoring message with invalid signature from " + msg.header.origin)
}
} else {
router.forwardMessage(msg)
} }
msg.body match {
case rreq: RouteRequest =>
localRoutesInfo.addRoute(msg.header.origin, rreq.originSeqNum, previousHop, rreq.originMetric)
// TODO: Respecting this causes the RERR test to fail. We have to fix the implementation
// of isMessageRedundant() without breaking the test.
if (routeMessageInfo.isMessageRedundant(msg)) {
logger.info("Sending redundant RREQ")
//return
}
if (crypto.localAddress == rreq.requested)
replyRoute(rreq.requested, msg.header.origin)
else {
val body = rreq.copy(originMetric = rreq.originMetric + 1)
val forwardMsg = crypto.sign(new Message(msg.header, body))
localRoutesInfo.getRoute(rreq.requested) match {
case Some(route) => router.forwardMessage(forwardMsg, Option(route.nextHop))
case None => router.forwardMessage(forwardMsg, Option(Address.Broadcast))
}
}
return
case rrep: RouteReply =>
localRoutesInfo.addRoute(msg.header.origin, rrep.originSeqNum, previousHop, 0)
// TODO: See above (in RREQ handler).
if (routeMessageInfo.isMessageRedundant(msg)) {
logger.debug("Sending redundant RREP")
//return
}
resendMissingRouteMessages()
if (msg.header.target == crypto.localAddress)
return
val existingRoute = localRoutesInfo.getRoute(msg.header.target)
val states = Set(LocalRoutesInfo.RouteStates.Active, LocalRoutesInfo.RouteStates.Idle)
if (existingRoute.isEmpty || !states.contains(existingRoute.get.state)) {
routeError(msg.header.target, Option(msg.header.origin))
return
}
val body = rrep.copy(originMetric = rrep.originMetric + 1)
val forwardMsg = crypto.sign(new Message(msg.header, body))
router.forwardMessage(forwardMsg)
return
case rerr: RouteError =>
localRoutesInfo.getRoute(rerr.address).foreach { route =>
if (route.nextHop == msg.header.origin && (rerr.seqNum == 0 || rerr.seqNum > route.seqNum)) {
localRoutesInfo.connectionClosed(rerr.address)
.foreach(routeError(_, None))
}
}
case _ =>
}
if (msg.header.target != crypto.localAddress) {
router.forwardMessage(msg)
return
}
val plainMsg =
try {
if (!crypto.verify(msg)) {
logger.warn(s"Received message with invalid signature from ${msg.header.origin}")
return
}
if (msg.header.isContentMessage)
crypto.decrypt(msg)
else
msg
} catch {
case e: InvalidKeyException =>
logger.warn(s"Failed to verify or decrypt message $msg", e)
return
}
onNewMessage(plainMsg)
}
/**
* Tries to send messages in [[missingRouteMessages]] again, after we acquired a new route.
*
* Before checking [[missingRouteMessages]], those older than [[MissingRouteMessageTimeout]]
* are removed.
*/
private def resendMissingRouteMessages(): Unit = {
// resend messages if possible
val date = new Date()
missingRouteMessages = missingRouteMessages.filter { e =>
val removeTime = new Date(e._2.getTime + MissingRouteMessageTimeout.toMillis)
removeTime.after(date)
}
val m = missingRouteMessages.filter(m => localRoutesInfo.getRoute(m._1.header.target).isDefined)
m.foreach( m => router.forwardMessage(m._1))
missingRouteMessages --= m
}
private def noRouteFound(message: Message): Unit = {
if (message.header.origin == crypto.localAddress) {
missingRouteMessages += ((message, new Date()))
requestRoute(message.header.target)
} else
routeError(message.header.target, Option(message.header.origin))
} }
/** /**
@ -163,7 +327,11 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
true true
} }
def onConnectionClosed() = callbacks.onConnectionsChanged() def onConnectionClosed(address: Address): Unit = {
localRoutesInfo.connectionClosed(address)
.foreach(routeError(_, None))
callbacks.onConnectionsChanged()
}
def connections(): Set[Address] = transmissionInterfaces.flatMap(_.getConnections) def connections(): Set[Address] = transmissionInterfaces.flatMap(_.getConnections)
@ -177,6 +345,9 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
.find(_.address == address) .find(_.address == address)
.getOrElse(new User(address, address.toString(), "")) .getOrElse(new User(address, address.toString(), ""))
/**
* This method should be called when the local device's internet connection has changed in any way.
*/
def internetConnectionChanged(): Unit = { def internetConnectionChanged(): Unit = {
transmissionInterfaces transmissionInterfaces
.find(_.isInstanceOf[InternetInterface]) .find(_.isInstanceOf[InternetInterface])

View file

@ -119,7 +119,7 @@ class Crypto(settings: SettingsInterface, keyFolder: File) {
* @throws RuntimeException If the key does not exist. * @throws RuntimeException If the key does not exist.
*/ */
@throws[RuntimeException] @throws[RuntimeException]
private[core] def getPublicKey(address: Address): PublicKey = { def getPublicKey(address: Address): PublicKey = {
loadKey(address.toString, classOf[PublicKey]) loadKey(address.toString, classOf[PublicKey])
} }
@ -129,7 +129,7 @@ class Crypto(settings: SettingsInterface, keyFolder: File) {
* @throws RuntimeException If a key already exists for this address. * @throws RuntimeException If a key already exists for this address.
*/ */
@throws[RuntimeException] @throws[RuntimeException]
private[core] def addPublicKey(address: Address, key: PublicKey): Unit = { def addPublicKey(address: Address, key: PublicKey): Unit = {
if (havePublicKey(address)) if (havePublicKey(address))
throw new RuntimeException("Already have key for " + address + ", not overwriting") throw new RuntimeException("Already have key for " + address + ", not overwriting")
@ -232,20 +232,6 @@ class Crypto(settings: SettingsInterface, keyFolder: File) {
sign(encrypt(msg, key)) sign(encrypt(msg, key))
} }
private[core] def verifyAndDecrypt(msg: Message, key: Option[PublicKey] = None): Option[Message] = {
// Catch exception to avoid crash if we receive invalid message.
try {
if (verify(msg, key))
Option(decrypt(msg))
else
None
} catch {
case e: InvalidKeyException =>
logger.warn("Failed to verify or decrypt message", e)
None
}
}
private def encrypt(msg: Message, key: Option[PublicKey] = None): Message = { private def encrypt(msg: Message, key: Option[PublicKey] = None): Message = {
// Symmetric encryption of data // Symmetric encryption of data
val secretKey = makeSecretKey() val secretKey = makeSecretKey()
@ -263,7 +249,7 @@ class Crypto(settings: SettingsInterface, keyFolder: File) {
} }
@throws[InvalidKeyException] @throws[InvalidKeyException]
private def decrypt(msg: Message): Message = { def decrypt(msg: Message): Message = {
// Asymmetric decryption of secret key // Asymmetric decryption of secret key
val asymmetricCipher = Cipher.getInstance(CipherAlgorithm) val asymmetricCipher = Cipher.getInstance(CipherAlgorithm)
asymmetricCipher.init(Cipher.UNWRAP_MODE, loadKey(PrivateKeyAlias, classOf[PrivateKey])) asymmetricCipher.init(Cipher.UNWRAP_MODE, loadKey(PrivateKeyAlias, classOf[PrivateKey]))

View file

@ -3,7 +3,7 @@ package com.nutomic.ensichat.core
import java.io.InputStream import java.io.InputStream
import java.security.spec.InvalidKeySpecException import java.security.spec.InvalidKeySpecException
import com.nutomic.ensichat.core.body.{ConnectionInfo, CryptoData, EncryptedBody, MessageBody} import com.nutomic.ensichat.core.body._
import com.nutomic.ensichat.core.header.{AbstractHeader, ContentHeader, MessageHeader} import com.nutomic.ensichat.core.header.{AbstractHeader, ContentHeader, MessageHeader}
object Message { object Message {
@ -50,6 +50,9 @@ object Message {
val body = val body =
header.protocolType match { header.protocolType match {
case ConnectionInfo.Type => ConnectionInfo.read(remaining) case ConnectionInfo.Type => ConnectionInfo.read(remaining)
case RouteRequest.Type => RouteRequest.read(remaining)
case RouteReply.Type => RouteReply.read(remaining)
case RouteError.Type => RouteError.read(remaining)
case _ => new EncryptedBody(remaining) case _ => new EncryptedBody(remaining)
} }
@ -80,6 +83,11 @@ case class Message(header: AbstractHeader, crypto: CryptoData, body: MessageBody
def this(header: AbstractHeader, body: MessageBody) = def this(header: AbstractHeader, body: MessageBody) =
this(header, new CryptoData(None, None), body) this(header, new CryptoData(None, None), body)
def write = header.write(body.length + crypto.length) ++ crypto.write ++ body.write def write = {
header.write(body.length + crypto.length) ++ crypto.write ++ body.write
}
override def toString =
s"Message(${header.origin.short}(${header.seqNum}) -> ${header.target.short}: $body)"
} }

View file

@ -1,18 +1,44 @@
package com.nutomic.ensichat.core package com.nutomic.ensichat.core
import java.util.Comparator
import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader} import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader}
import com.nutomic.ensichat.core.util.LocalRoutesInfo
object Router extends Comparator[Int] {
/**
* Compares which sequence number is newer.
*
* @return 1 if lhs is newer, -1 if rhs is newer, 0 if they are equal.
*/
override def compare(lhs: Int, rhs: Int): Int = {
if (lhs == rhs)
0
// True if [[rhs]] is between {{{MessageHeader.SeqNumRange.size / 2}}} and
// [[MessageHeader.SeqNumRange.size]].
else if (lhs > ContentHeader.SeqNumRange.size / 2) {
// True if [[rhs]] is between {{{lhs - MessageHeader.SeqNumRange.size / 2}}} and [[lhs]].
if (lhs - ContentHeader.SeqNumRange.size / 2 < rhs && rhs < lhs) 1 else -1
} else {
// True if [[rhs]] is *not* between [[lhs]] and {{{lhs + MessageHeader.SeqNumRange.size / 2}}}.
if (rhs < lhs || rhs > lhs + ContentHeader.SeqNumRange.size / 2) 1 else -1
}
}
}
/** /**
* Forwards messages to all connected devices. * Forwards messages to all connected devices.
*/ */
final private[core] class Router(activeConnections: () => Set[Address], send: (Address, Message) => Unit) { private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message) => Unit,
noRouteFound: (Message) => Unit) {
private var messageSeen = Set[(Address, Int)]() private var messageSeen = Set[(Address, Int)]()
/** /**
* Returns true if we have received the same message before. * Returns true if we have received the same message before.
*/ */
def isMessageSeen(msg: Message): Boolean = { private[core] def isMessageSeen(msg: Message): Boolean = {
val info = (msg.header.origin, msg.header.seqNum) val info = (msg.header.origin, msg.header.seqNum)
val seen = messageSeen.contains(info) val seen = messageSeen.contains(info)
markMessageSeen(info) markMessageSeen(info)
@ -23,15 +49,24 @@ final private[core] class Router(activeConnections: () => Set[Address], send: (A
* Sends message to all connected devices. Should only be called if [[isMessageSeen()]] returns * Sends message to all connected devices. Should only be called if [[isMessageSeen()]] returns
* true. * true.
*/ */
def forwardMessage(msg: Message): Unit = { def forwardMessage(msg: Message, nextHopOption: Option[Address] = None): Unit = {
val info = (msg.header.origin, msg.header.seqNum) if (msg.header.hopCount + 1 >= msg.header.hopLimit)
val updated = incHopCount(msg)
if (updated.header.hopCount >= updated.header.hopLimit)
return return
activeConnections().foreach(a => send(a, updated)) val nextHop = nextHopOption.getOrElse(msg.header.target)
markMessageSeen(info) if (nextHop == Address.Broadcast) {
send(nextHop, msg)
return
}
routesInfo.getRoute(nextHop).map(_.nextHop) match {
case Some(a) =>
send(a, incHopCount(msg))
markMessageSeen((msg.header.origin, msg.header.seqNum))
case None =>
noRouteFound(msg)
}
} }
private def markMessageSeen(info: (Address, Int)): Unit = { private def markMessageSeen(info: (Address, Int)): Unit = {
@ -64,15 +99,8 @@ final private[core] class Router(activeConnections: () => Set[Address], send: (A
if (a1 != a2) if (a1 != a2)
true true
// True if [[s2]] is between {{{MessageHeader.SeqNumRange.size / 2}}} and else
// [[MessageHeader.SeqNumRange.size]]. Router.compare(s1, s2) > 0
if (s1 > ContentHeader.SeqNumRange.size / 2) {
// True if [[s2]] is between {{{s1 - MessageHeader.SeqNumRange.size / 2}}} and [[s1]].
s1 - ContentHeader.SeqNumRange.size / 2 < s2 && s2 < s1
} else {
// True if [[s2]] is *not* between [[s1]] and {{{s1 + MessageHeader.SeqNumRange.size / 2}}}.
s2 < s1 || s2 > s1 + ContentHeader.SeqNumRange.size / 2
}
} }
} }

View file

@ -12,4 +12,6 @@ final case class EncryptedBody(data: Array[Byte]) extends MessageBody {
def write = data def write = data
override def length = data.length override def length = data.length
override def toString = "EncryptedBody"
} }

View file

@ -0,0 +1,39 @@
package com.nutomic.ensichat.core.body
import java.nio.ByteBuffer
import com.nutomic.ensichat.core.Address
import com.nutomic.ensichat.core.util.BufferUtils
private[core] object RouteError {
val Type = 4
/**
* Constructs [[RouteError]] instance from byte array.
*/
def read(array: Array[Byte]): RouteError = {
val b = ByteBuffer.wrap(array)
val address = new Address(BufferUtils.getByteArray(b, Address.Length))
val seqNum = b.getInt
new RouteError(address, seqNum)
}
}
private[core] case class RouteError(address: Address, seqNum: Int) extends MessageBody {
override def protocolType = RouteReply.Type
override def contentType = -1
override def write: Array[Byte] = {
val b = ByteBuffer.allocate(length)
b.put(address.bytes)
b.putInt(seqNum)
b.array()
}
override def length = Address.Length + 4
}

View file

@ -0,0 +1,50 @@
package com.nutomic.ensichat.core.body
import java.nio.ByteBuffer
import com.nutomic.ensichat.core.util.BufferUtils
private[core] object RouteReply {
val Type = 3
/**
* Constructs [[RouteReply]] instance from byte array.
*/
def read(array: Array[Byte]): RouteReply = {
val b = ByteBuffer.wrap(array)
val targSeqNum = BufferUtils.getUnsignedShort(b)
val targMetric = BufferUtils.getUnsignedShort(b)
new RouteReply(targSeqNum, targMetric)
}
}
/**
* Sends information about a route.
*
* Note that the fields are named different than described in AODVv2. There, targSeqNum and
* targMetric are used to describe the seqNum and metric of the node sending the route reply. In
* Ensichat, we use originSeqNum and originMetric instead, to stay consistent with the header
* fields. That means header.origin, originSeqNum and originMetric all refer to the node sending
* this message.
*
* @param originSeqNum The current sequence number of the node sending this message.
* @param originMetric The metric of the current route to the sending node.
*/
private[core] case class RouteReply(originSeqNum: Int, originMetric: Int) extends MessageBody {
override def protocolType = RouteReply.Type
override def contentType = -1
override def write: Array[Byte] = {
val b = ByteBuffer.allocate(length)
BufferUtils.putUnsignedShort(b, originSeqNum)
BufferUtils.putUnsignedShort(b, originMetric)
b.array()
}
override def length = 4
}

View file

@ -0,0 +1,44 @@
package com.nutomic.ensichat.core.body
import java.nio.ByteBuffer
import com.nutomic.ensichat.core.Address
import com.nutomic.ensichat.core.util.BufferUtils
private[core] object RouteRequest {
val Type = 2
/**
* Constructs [[RouteRequest]] instance from byte array.
*/
def read(array: Array[Byte]): RouteRequest = {
val b = ByteBuffer.wrap(array)
val requested = new Address(BufferUtils.getByteArray(b, Address.Length))
val origSeqNum = BufferUtils.getUnsignedShort(b)
val originMetric = BufferUtils.getUnsignedShort(b)
val targSeqNum = b.getInt()
new RouteRequest(requested, origSeqNum, targSeqNum, originMetric)
}
}
private[core] case class RouteRequest(requested: Address, originSeqNum: Int, targSeqNum: Int, originMetric: Int)
extends MessageBody {
override def protocolType = RouteRequest.Type
override def contentType = -1
override def write: Array[Byte] = {
val b = ByteBuffer.allocate(length)
b.put(requested.bytes)
BufferUtils.putUnsignedShort(b, originSeqNum)
BufferUtils.putUnsignedShort(b, originMetric)
b.putInt(targSeqNum)
b.array()
}
override def length = 8 + Address.Length
}

View file

@ -12,8 +12,10 @@ import com.typesafe.scalalogging.Logger
/** /**
* Encapsulates an active connection to another node. * Encapsulates an active connection to another node.
*/ */
class InternetConnectionThread(socket: Socket, crypto: Crypto, onDisconnected: (InternetConnectionThread) => Unit, private[core] class InternetConnectionThread(socket: Socket, crypto: Crypto,
onReceive: (Message, InternetConnectionThread) => Unit) extends Thread { onDisconnected: (InternetConnectionThread) => Unit,
onReceive: (Message, InternetConnectionThread) => Unit)
extends Thread {
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
@ -78,7 +80,6 @@ class InternetConnectionThread(socket: Socket, crypto: Crypto, onDisconnected: (
} catch { } catch {
case e: IOException => logger.warn("Failed to close socket", e) case e: IOException => logger.warn("Failed to close socket", e)
} }
logger.debug("Connection to " + socket.getInetAddress + " closed")
onDisconnected(this) onDisconnected(this)
} }

View file

@ -12,9 +12,9 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future import scala.concurrent.Future
import scala.util.Random import scala.util.Random
object InternetInterface { private[core] object InternetInterface {
val ServerPort = 26344 val DefaultPort = 26344
} }
@ -23,14 +23,14 @@ object InternetInterface {
* *
* @param maxConnections Maximum number of concurrent connections that should be opened. * @param maxConnections Maximum number of concurrent connections that should be opened.
*/ */
class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto, private[core] class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto,
settings: SettingsInterface, maxConnections: Int) settings: SettingsInterface, maxConnections: Int, port: Int)
extends TransmissionInterface { extends TransmissionInterface {
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
private lazy val serverThread = private lazy val serverThread =
new InternetServerThread(crypto, onConnected, onDisconnected, onReceiveMessage) new InternetServerThread(crypto, port, onConnected, onDisconnected, onReceiveMessage)
private var connections = Set[InternetConnectionThread]() private var connections = Set[InternetConnectionThread]()
@ -44,10 +44,8 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto,
.replace("46.101.249.188:26344", SettingsInterface.DefaultAddresses) .replace("46.101.249.188:26344", SettingsInterface.DefaultAddresses)
settings.put(SettingsInterface.KeyAddresses, servers) settings.put(SettingsInterface.KeyAddresses, servers)
FutureHelper { serverThread.start()
serverThread.start() openAllConnections(maxConnections)
openAllConnections(maxConnections)
}
} }
/** /**
@ -69,13 +67,13 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto,
.foreach(openConnection) .foreach(openConnection)
} }
private def openConnection(addressPort: String): Unit = { def openConnection(addressPort: String): Unit = {
val (address, port) = val (address, port) =
if (addressPort.contains(":")) { if (addressPort.contains(":")) {
val split = addressPort.split(":") val split = addressPort.split(":")
(split(0), split(1).toInt) (split(0), split(1).toInt)
} else } else
(addressPort, InternetInterface.ServerPort) (addressPort, InternetInterface.DefaultPort)
openConnection(address, port) openConnection(address, port)
} }
@ -100,11 +98,11 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto,
} }
private def onDisconnected(connectionThread: InternetConnectionThread): Unit = { private def onDisconnected(connectionThread: InternetConnectionThread): Unit = {
addressDeviceMap.find(_._2 == connectionThread).foreach { ad => getAddressForThread(connectionThread).foreach { ad =>
logger.trace("Connection closed to " + ad._1) logger.trace("Connection closed to " + ad)
connections -= connectionThread connections -= connectionThread
addressDeviceMap -= ad._1 addressDeviceMap -= ad
connectionHandler.onConnectionClosed() connectionHandler.onConnectionClosed(ad)
} }
} }
@ -122,15 +120,18 @@ class InternetInterface(connectionHandler: ConnectionHandler, crypto: Crypto,
if (!connectionHandler.onConnectionOpened(msg)) if (!connectionHandler.onConnectionOpened(msg))
addressDeviceMap -= address addressDeviceMap -= address
case _ => case _ =>
connectionHandler.onMessageReceived(msg) connectionHandler.onMessageReceived(msg, getAddressForThread(thread).get)
} }
private def getAddressForThread(thread: InternetConnectionThread) =
addressDeviceMap.find(_._2 == thread).map(_._1)
/** /**
* Sends the message to nextHop. * Sends the message to nextHop.
*/ */
override def send(nextHop: Address, msg: Message): Unit = { override def send(nextHop: Address, msg: Message): Unit = {
addressDeviceMap addressDeviceMap
.find(_._1 == nextHop) .filter(_._1 == nextHop || Address.Broadcast == nextHop)
.foreach(_._2.send(msg)) .foreach(_._2.send(msg))
} }

View file

@ -6,13 +6,15 @@ import java.net.ServerSocket
import com.nutomic.ensichat.core.{Crypto, Message} import com.nutomic.ensichat.core.{Crypto, Message}
import com.typesafe.scalalogging.Logger import com.typesafe.scalalogging.Logger
class InternetServerThread(crypto: Crypto, onConnected: (InternetConnectionThread) => Unit, class InternetServerThread(crypto: Crypto, port: Int,
onDisconnected: (InternetConnectionThread) => Unit, onReceive: (Message, InternetConnectionThread) => Unit) extends Thread { onConnected: (InternetConnectionThread) => Unit,
onDisconnected: (InternetConnectionThread) => Unit,
onReceive: (Message, InternetConnectionThread) => Unit) extends Thread {
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
private lazy val socket: Option[ServerSocket] = try { private lazy val socket: Option[ServerSocket] = try {
Option(new ServerSocket(InternetInterface.ServerPort)) Option(new ServerSocket(port))
} catch { } catch {
case e: IOException => case e: IOException =>
logger.warn("Failed to create server socket", e) logger.warn("Failed to create server socket", e)

View file

@ -75,8 +75,9 @@ class Database(path: File, callbackInterface: CallbackInterface) {
/** /**
* Inserts the given new message into the database. * Inserts the given new message into the database.
*/ */
def onMessageReceived(msg: Message): Unit = { def onMessageReceived(msg: Message): Unit = msg.body match {
Await.result(db.run(messages += msg), Duration.Inf) case _: Text => Await.result(db.run(messages += msg), Duration.Inf)
case _ =>
} }
def getMessages(address: Address): Seq[Message] = { def getMessages(address: Address): Seq[Message] = {

View file

@ -0,0 +1,119 @@
package com.nutomic.ensichat.core.util
import com.nutomic.ensichat.core.Address
import com.nutomic.ensichat.core.util.LocalRoutesInfo._
import com.typesafe.scalalogging.Logger
import org.joda.time.{DateTime, Duration}
private[core] object LocalRoutesInfo {
private val ActiveInterval = Duration.standardSeconds(5)
/**
* [[RouteStates.Idle]]:
* A route that is known, but has not been used in the last [[ActiveInterval.
* [[RouteStates.Active]]:
* A route that is known, and has been used in the last [[ActiveInterval]].
* [[RouteStates.Invalid]]:
* A route that has been expired or lost, may not be used for forwarding.
* RouteStates.Unconfirmed is not required as connections are always bidirectional.
*/
object RouteStates extends Enumeration {
type RouteStates = Value
val Idle, Active, Invalid = Value
}
}
/**
* This class contains information about routes available to this node.
*
* See AODVv2-13 4.5 (Local Route Set), -> implemented
* 6.9 (Local Route Set Maintenance) -> implemented (hopefully correct)
*/
private[core] class LocalRoutesInfo(activeConnections: () => Set[Address]) {
import RouteStates._
private val MaxSeqnumLifetime = Duration.standardSeconds(300)
// TODO: this can probably be much higher because of infrequent topology changes between internet nodes
private val MaxIdleTime = Duration.standardSeconds(300)
/**
* Holds information about a local route.
*
* @param destination The destination address that can be reached with this route.
* @param seqNum Sequence number of the last route message that updated this entry.
* @param nextHop The next hop on the path towards destination.
* @param lastUsed The time this route was last used to forward a message.
* @param lastSeqNumUpdate The time seqNum was last updated.
* @param metric The number of hops towards destination using this route.
* @param state The last known state of the route.
*/
case class RouteEntry(destination: Address, seqNum: Int, nextHop: Address, lastUsed: DateTime,
lastSeqNumUpdate: DateTime, metric: Int, state: RouteStates)
private var routes = Set[RouteEntry]()
def addRoute(destination: Address, seqNum: Int, nextHop: Address, metric: Int): Unit = {
val entry = RouteEntry(destination, seqNum, nextHop, new DateTime(0), DateTime.now, metric, Idle)
routes += entry
}
def getRoute(destination: Address): Option[RouteEntry] = {
if (activeConnections().contains(destination))
return Option(new RouteEntry(destination, 0, destination, DateTime.now, DateTime.now, 1, Idle))
handleTimeouts()
val r = routes.toList
.sortWith(_.metric < _.metric)
.find( r => r.destination == destination && r.state != Invalid)
if (r.isDefined)
routes = routes -- r + r.get.copy(state = Active, lastUsed = DateTime.now)
r
}
/**
*
* @param address The address which can't be reached any more.
* @return The set of active destinations that can't be reached anymore.
*/
def connectionClosed(address: Address): Set[Address] = {
handleTimeouts()
val affectedDestinations =
routes
.filter(r => r.state == Active && (r.nextHop == address || r.destination == address))
.map(_.destination)
routes = routes.map { r =>
if (r.nextHop == address || r.destination == address)
r.copy(state = Invalid)
else
r
}
affectedDestinations
}
private def handleTimeouts(): Unit = {
routes = routes
// Delete routes after max lifetime.
.map { r =>
if (DateTime.now.isAfter(r.lastSeqNumUpdate.plus(MaxSeqnumLifetime)))
r.copy(seqNum = 0)
else
r
}
// Set routes to invalid after max idle time.
.map { r =>
if (DateTime.now.isAfter(r.lastSeqNumUpdate.plus(MaxIdleTime)))
r.copy(state = Invalid)
else
r
}
}
}

View file

@ -0,0 +1,74 @@
package com.nutomic.ensichat.core.util
import com.nutomic.ensichat.core.body.{RouteReply, RouteRequest}
import com.nutomic.ensichat.core.{Address, Message, Router}
import org.joda.time.{DateTime, Duration}
/**
* Contains information about AODVv2 control messages that have been received.
*
* This class handles Route Request and Route Reply messages (referred to as "route messages").
*
* See AODVv2-13 4.6 (Multicast Route Message Table), -> implemented
* 6.8 (Surpressing Redundant Messages Using the Multicast Route Message Table) -> implemented (hopefully correct)
*/
private[core] class RouteMessageInfo {
private val MaxSeqnumLifetime = Duration.standardSeconds(300)
/**
* @param messageType Either [[RouteRequest.Type]] or [[RouteReply.Type]].
* @param origAddress Source address of the route message triggering the route request.
* @param targAddress Destination address of the route message triggering the route request.
* @param origSeqNum Sequence number associated with the route to [[origAddress]], if route
* message is an RREQ.
* @param targSeqNum Sequence number associated with the route to [[targAddress]], if present in
* the route message.
* @param metric Metric value received in the route message.
* @param timestamp Last time this entry was updated.
*/
private case class RouteMessageEntry(messageType: Int, origAddress: Address,
targAddress: Address, origSeqNum: Int, targSeqNum: Int,
metric: Int, timestamp: DateTime)
private var entries = Set[RouteMessageEntry]()
private def addEntry(msg: Message): Unit = msg.body match {
case rreq: RouteRequest =>
entries += new RouteMessageEntry(RouteRequest.Type, msg.header.origin, msg.header.target,
msg.header.seqNum, rreq.targSeqNum, rreq.originMetric,
DateTime.now)
case rrep: RouteReply =>
entries += new RouteMessageEntry(RouteReply.Type, msg.header.origin, msg.header.target,
msg.header.seqNum, rrep.originSeqNum, rrep.originMetric,
DateTime.now)
}
def isMessageRedundant(msg: Message): Boolean = {
handleTimeouts()
val existingEntry =
entries.find { e =>
val haveEntry = e.messageType == msg.header.protocolType &&
e.origAddress == msg.header.origin && e.targAddress == msg.header.target
val (metric, seqNumComparison) = msg.body match {
case rreq: RouteRequest => (rreq.originMetric, Router.compare(rreq.originSeqNum, e.origSeqNum))
case rrep: RouteReply => (rrep.originMetric, Router.compare(rrep.originSeqNum, e.targSeqNum))
}
val isMetricBetter = e.metric < metric
haveEntry && (seqNumComparison > 0 || (seqNumComparison == 0 && isMetricBetter))
}
if (existingEntry.isDefined)
entries = entries - existingEntry.get
addEntry(msg)
existingEntry.isDefined
}
private def handleTimeouts(): Unit = {
entries = entries.filter { e =>
DateTime.now.isBefore(e.timestamp.plus(MaxSeqnumLifetime))
}
}
}

View file

@ -41,8 +41,9 @@ class CryptoTest extends TestCase {
def testEncryptDecrypt(): Unit = { def testEncryptDecrypt(): Unit = {
MessageTest.messages.foreach{ m => MessageTest.messages.foreach{ m =>
val encrypted = crypto.encryptAndSign(m, Option(crypto.getLocalPublicKey)) val encrypted = crypto.encryptAndSign(m, Option(crypto.getLocalPublicKey))
val decrypted = crypto.verifyAndDecrypt(encrypted, Option(crypto.getLocalPublicKey)) assertTrue(crypto.verify(encrypted, Option(crypto.getLocalPublicKey)))
assertEquals(m.body, decrypted.get.body) val decrypted = crypto.decrypt(encrypted)
assertEquals(m.body, decrypted.body)
assertEquals(m.header, encrypted.header) assertEquals(m.header, encrypted.header)
} }
} }

View file

@ -58,9 +58,10 @@ class MessageTest extends TestCase {
val read = Message.read(new ByteArrayInputStream(bytes)) val read = Message.read(new ByteArrayInputStream(bytes))
assertEquals(encrypted.crypto, read.crypto) assertEquals(encrypted.crypto, read.crypto)
val decrypted = crypto.verifyAndDecrypt(read, Option(crypto.getLocalPublicKey)) assertTrue(crypto.verify(read, Option(crypto.getLocalPublicKey)))
assertEquals(m.header, decrypted.get.header) val decrypted = crypto.decrypt(read)
assertEquals(m.body, decrypted.get.body) assertEquals(m.header, decrypted.header)
assertEquals(m.body, decrypted.body)
} }
} }

View file

@ -1,34 +1,45 @@
package com.nutomic.ensichat.core package com.nutomic.ensichat.core
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.{Date, GregorianCalendar} import java.util.{Date, GregorianCalendar}
import com.nutomic.ensichat.core.body.{Text, UserInfo} import com.nutomic.ensichat.core.body.{Text, UserInfo}
import com.nutomic.ensichat.core.header.ContentHeader import com.nutomic.ensichat.core.header.ContentHeader
import com.nutomic.ensichat.core.util.LocalRoutesInfo
import junit.framework.TestCase import junit.framework.TestCase
import org.junit.Assert._ import org.junit.Assert._
class RouterTest extends TestCase { class RouterTest extends TestCase {
private def neighbors() = Set[Address](AddressTest.a1, AddressTest.a2, AddressTest.a3) private def neighbors() = Set[Address](AddressTest.a1, AddressTest.a2, AddressTest.a4)
private val msg = generateMessage(AddressTest.a1, AddressTest.a4, 1) def testNoRouteFound(): Unit = {
val msg = generateMessage(AddressTest.a2, AddressTest.a3, 1)
val latch = new CountDownLatch(1)
val router = new Router(new LocalRoutesInfo(neighbors),
(_, _) => fail("Message shouldn't be forwarded"), m => {
assertEquals(msg, m)
latch.countDown()
})
router.forwardMessage(msg)
assertTrue(latch.await(1, TimeUnit.SECONDS))
}
/** def testNextHop(): Unit = {
* Messages should be sent to all neighbors. val msg = generateMessage(AddressTest.a1, AddressTest.a4, 1)
*/
def testFlooding(): Unit = {
var sentTo = Set[Address]() var sentTo = Set[Address]()
val router: Router = new Router(neighbors, val router = new Router(new LocalRoutesInfo(neighbors),
(a, m) => { (a, m) => {
sentTo += a sentTo += a
}) }, _ => ())
router.forwardMessage(msg) router.forwardMessage(msg)
assertEquals(neighbors(), sentTo) assertEquals(Set(AddressTest.a4), sentTo)
} }
def testMessageSame(): Unit = { def testMessageSame(): Unit = {
val router: Router = new Router(neighbors, val msg = generateMessage(AddressTest.a1, AddressTest.a4, 1)
val router = new Router(new LocalRoutesInfo(neighbors),
(a, m) => { (a, m) => {
assertEquals(msg.header.origin, m.header.origin) assertEquals(msg.header.origin, m.header.origin)
assertEquals(msg.header.target, m.header.target) assertEquals(msg.header.target, m.header.target)
@ -38,7 +49,7 @@ class RouterTest extends TestCase {
assertEquals(msg.header.hopLimit, m.header.hopLimit) assertEquals(msg.header.hopLimit, m.header.hopLimit)
assertEquals(msg.body, m.body) assertEquals(msg.body, m.body)
assertEquals(msg.crypto, m.crypto) assertEquals(msg.crypto, m.crypto)
}) }, _ => ())
router.forwardMessage(msg) router.forwardMessage(msg)
} }
@ -47,26 +58,32 @@ class RouterTest extends TestCase {
*/ */
def testDifferentSenders(): Unit = { def testDifferentSenders(): Unit = {
var sentTo = Set[Address]() var sentTo = Set[Address]()
val router: Router = new Router(neighbors, (a, m) => sentTo += a) val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => sentTo += a, _ => ())
router.forwardMessage(msg) router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, 1))
assertEquals(neighbors(), sentTo) assertEquals(Set(AddressTest.a4), sentTo)
sentTo = Set[Address]() sentTo = Set[Address]()
router.forwardMessage(generateMessage(AddressTest.a2, AddressTest.a4, 1)) router.forwardMessage(generateMessage(AddressTest.a2, AddressTest.a4, 1))
assertEquals(neighbors(), sentTo) assertEquals(Set(AddressTest.a4), sentTo)
}
def testSeqNumComparison(): Unit = {
Router.compare(1, ContentHeader.SeqNumRange.last)
Router.compare(ContentHeader.SeqNumRange.last / 2, ContentHeader.SeqNumRange.last)
Router.compare(ContentHeader.SeqNumRange.last / 2, 1)
} }
def testDiscardOldIgnores(): Unit = { def testDiscardOldIgnores(): Unit = {
def test(first: Int, second: Int) { def test(first: Int, second: Int) {
var sentTo = Set[Address]() var sentTo = Set[Address]()
val router: Router = new Router(neighbors, (a, m) => sentTo += a) val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => sentTo += a, _ => ())
router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a3, first)) router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, first))
router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a3, second)) router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, second))
sentTo = Set[Address]() sentTo = Set[Address]()
router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a3, first)) router.forwardMessage(generateMessage(AddressTest.a1, AddressTest.a4, first))
assertEquals(neighbors(), sentTo) assertEquals(Set(AddressTest.a4), sentTo)
} }
test(1, ContentHeader.SeqNumRange.last) test(1, ContentHeader.SeqNumRange.last)
@ -77,7 +94,7 @@ class RouterTest extends TestCase {
def testHopLimit(): Unit = Range(19, 22).foreach { i => def testHopLimit(): Unit = Range(19, 22).foreach { i =>
val msg = new Message( val msg = new Message(
new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), i), new Text("")) new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), i), new Text(""))
val router: Router = new Router(neighbors, (a, m) => fail()) val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => fail(), _ => ())
router.forwardMessage(msg) router.forwardMessage(msg)
} }

View file

@ -0,0 +1,16 @@
package com.nutomic.ensichat.core.body
import com.nutomic.ensichat.core.AddressTest
import junit.framework.TestCase
import org.junit.Assert._
class RouteErrorTest extends TestCase {
def testWriteRead(): Unit = {
val rerr = new RouteError(AddressTest.a2, 62000)
val bytes = rerr.write
val parsed = RouteError.read(bytes)
assertEquals(rerr, parsed)
}
}

View file

@ -0,0 +1,16 @@
package com.nutomic.ensichat.core.body
import com.nutomic.ensichat.core.AddressTest
import junit.framework.TestCase
import org.junit.Assert._
class RouteReplyTest extends TestCase {
def testWriteRead(): Unit = {
val rrep = new RouteReply(61000, 123)
val bytes = rrep.write
val parsed = RouteReply.read(bytes)
assertEquals(rrep, parsed)
}
}

View file

@ -0,0 +1,16 @@
package com.nutomic.ensichat.core.body
import com.nutomic.ensichat.core.AddressTest
import junit.framework.TestCase
import org.junit.Assert._
class RouteRequestTest extends TestCase {
def testWriteRead(): Unit = {
val rreq = new RouteRequest(AddressTest.a2, 60000, 60001, 60002)
val bytes = rreq.write
val parsed = RouteRequest.read(bytes)
assertEquals(rreq, parsed)
}
}

View file

@ -0,0 +1,45 @@
package com.nutomic.ensichat.core.util
import com.nutomic.ensichat.core.AddressTest
import junit.framework.TestCase
import org.joda.time.{DateTime, DateTimeUtils, Duration}
import org.junit.Assert._
class LocalRoutesInfoTest extends TestCase {
private def connections() = Set(AddressTest.a1, AddressTest.a2)
def testRoute(): Unit = {
val routesInfo = new LocalRoutesInfo(connections)
routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1)
val route = routesInfo.getRoute(AddressTest.a3)
assertEquals(AddressTest.a1, route.get.nextHop)
}
def testBestMetric(): Unit = {
val routesInfo = new LocalRoutesInfo(connections)
routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1)
routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a2, 2)
val route = routesInfo.getRoute(AddressTest.a3)
assertEquals(AddressTest.a1, route.get.nextHop)
}
def testConnectionClosed(): Unit = {
val routesInfo = new LocalRoutesInfo(connections)
routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1)
routesInfo.addRoute(AddressTest.a4, 0, AddressTest.a1, 1)
// Mark the route as active, because only active routes are returned.
routesInfo.getRoute(AddressTest.a3)
val unreachable = routesInfo.connectionClosed(AddressTest.a1)
assertEquals(Set(AddressTest.a3), unreachable)
}
def testTimeout(): Unit = {
DateTimeUtils.setCurrentMillisFixed(new DateTime().getMillis)
val routesInfo = new LocalRoutesInfo(connections)
routesInfo.addRoute(AddressTest.a3, 0, AddressTest.a1, 1)
DateTimeUtils.setCurrentMillisFixed(DateTime.now.plus(Duration.standardSeconds(400)).getMillis)
assertEquals(None, routesInfo.getRoute(AddressTest.a3))
}
}

View file

@ -0,0 +1,79 @@
package com.nutomic.ensichat.core.util
import com.nutomic.ensichat.core.body.{RouteReply, RouteRequest}
import com.nutomic.ensichat.core.header.MessageHeader
import com.nutomic.ensichat.core.{AddressTest, Message}
import junit.framework.TestCase
import org.joda.time.{DateTime, DateTimeUtils, Duration}
import org.junit.Assert._
class RouteMessageInfoTest extends TestCase {
/**
* Test case in which we have an entry with the same type, origin and target.
*/
def testSameMessage(): Unit = {
val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1)
val msg = new Message(header, new RouteRequest(AddressTest.a3, 2, 3, 1))
val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg))
assertTrue(rmi.isMessageRedundant(msg))
}
/**
* Forward a message with a seqnum that is older than the latest.
*/
def testSeqNumOlder(): Unit = {
val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1)
val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 0, 0, 0))
val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg1))
val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 3)
val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 2, 0, 0))
assertTrue(rmi.isMessageRedundant(msg2))
}
/**
* Announce a route with a metric that is worse than the existing one.
*/
def testMetricWorse(): Unit = {
val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1)
val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 1, 0, 2))
val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg1))
val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2)
val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 1, 0, 4))
assertTrue(rmi.isMessageRedundant(msg2))
}
/**
* Announce route with a better metric.
*/
def testMetricBetter(): Unit = {
val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1)
val msg1 = new Message(header1, new RouteReply(0, 4))
val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg1))
val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2)
val msg2 = new Message(header2, new RouteReply(0, 2))
assertFalse(rmi.isMessageRedundant(msg2))
}
/**
* Test that entries are removed after [[RouteMessageInfo.MaxSeqnumLifetime]].
*/
def testTimeout(): Unit = {
val rmi = new RouteMessageInfo()
DateTimeUtils.setCurrentMillisFixed(DateTime.now.getMillis)
val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1)
val msg = new Message(header, new RouteRequest(AddressTest.a3, 0, 0, 0))
assertFalse(rmi.isMessageRedundant(msg))
DateTimeUtils.setCurrentMillisFixed(DateTime.now.plus(Duration.standardSeconds(400)).getMillis)
assertFalse(rmi.isMessageRedundant(msg))
}
}

1
integration/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/build

12
integration/build.gradle Normal file
View file

@ -0,0 +1,12 @@
apply plugin: 'scala'
apply plugin: 'application'
dependencies {
compile 'org.scala-lang:scala-library:2.11.7'
compile 'com.github.scala-incubator.io:scala-io-file_2.11:0.4.3'
compile project(path: ':core')
}
mainClassName = 'com.nutomic.ensichat.integration.Main'
version = "0.2.3"
applicationName = 'ensichat-server'

View file

@ -0,0 +1,84 @@
package com.nutomic.ensichat.integration
import java.io.File
import java.util.concurrent.{LinkedBlockingDeque, LinkedBlockingQueue}
import com.nutomic.ensichat.core.body.{RouteError, RouteRequest, RouteReply}
import com.nutomic.ensichat.core.interfaces.{CallbackInterface, SettingsInterface}
import com.nutomic.ensichat.core.util.Database
import com.nutomic.ensichat.core.{ConnectionHandler, Crypto, Message}
import com.nutomic.ensichat.integration.LocalNode._
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scalax.file.Path
object LocalNode {
private final val StartingPort = 21000
object EventType extends Enumeration {
type EventType = Value
val MessageReceived, ConnectionsChanged, ContactsUpdated = Value
}
class FifoStream[A]() {
private val queue = new LinkedBlockingQueue[Option[A]]()
def toStream: Stream[A] = queue.take match {
case Some(a) => Stream.cons(a, toStream)
case None => Stream.empty
}
def close() = queue add None
def enqueue(a: A) = queue.put(Option(a))
}
}
/**
* Runs an ensichat node on localhost.
*
* Received messages can be accessed through [[eventQueue]].
*
* @param index Number of this node. The server port is opened on port [[StartingPort]] + index.
* @param configFolder Folder where keys and configuration should be stored.
*/
class LocalNode(val index: Int, configFolder: File) extends CallbackInterface {
import com.nutomic.ensichat.integration.LocalNode.EventType._
private val databaseFile = new File(configFolder, "database")
private val keyFolder = new File(configFolder, "keys")
private val database = new Database(databaseFile, this)
private val settings = new SettingsInterface {
private var values = Map[String, Any]()
override def get[T](key: String, default: T): T = values.get(key).map(_.asInstanceOf[T]).getOrElse(default)
override def put[T](key: String, value: T): Unit = values += (key -> value.asInstanceOf[Any])
}
val crypto = new Crypto(settings, keyFolder)
val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 0, port)
val eventQueue = new FifoStream[(EventType.EventType, Option[Message])]()
configFolder.mkdirs()
keyFolder.mkdirs()
settings.put(SettingsInterface.KeyAddresses, "")
Await.result(connectionHandler.start(), Duration.Inf)
def port = StartingPort + index
def stop(): Unit = {
connectionHandler.stop()
Path(configFolder).deleteRecursively()
}
def onMessageReceived(msg: Message): Unit = {
eventQueue.enqueue((EventType.MessageReceived, Option(msg)))
}
def onConnectionsChanged(): Unit =
eventQueue.enqueue((EventType.ConnectionsChanged, None))
def onContactsUpdated(): Unit =
eventQueue.enqueue((EventType.ContactsUpdated, None))
}

View file

@ -0,0 +1,135 @@
package com.nutomic.ensichat.integration
import java.io.File
import java.util.concurrent.{CountDownLatch, TimeUnit}
import com.nutomic.ensichat.core.Crypto
import com.nutomic.ensichat.core.body.Text
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}
import scala.util.Try
/**
* Creates some local nodes, connects them and sends messages between them.
*
* If the test runs slow or fails, changing [[Crypto.PublicKeySize]] to 512 should help.
*/
object Main extends App {
val nodes = createMesh()
System.out.println("\n\nAll nodes connected!\n\n")
sendMessages(nodes)
System.out.println("\n\nAll messages sent!\n\n")
// Stop node 1, forcing route errors and messages to use the (longer) path via nodes 7 and 8.
nodes(1).connectionHandler.stop()
System.out.println("node 1 stopped")
sendMessages(nodes)
/**
* Creates a new mesh with a predefined layout.
*
* Graphical representation:
* 8 7
* / \
* 0134
* \ / | |
* 2 56
*
* @return List of [[LocalNode]]s, ordered from 0 to 7.
*/
private def createMesh(): Seq[LocalNode] = {
val nodes = Await.result(Future.sequence(0.to(8).map(createNode)), Duration.Inf)
sys.addShutdownHook(nodes.foreach(_.stop()))
connectNodes(nodes(0), nodes(1))
connectNodes(nodes(0), nodes(2))
connectNodes(nodes(1), nodes(2))
connectNodes(nodes(1), nodes(3))
connectNodes(nodes(3), nodes(4))
connectNodes(nodes(3), nodes(5))
connectNodes(nodes(4), nodes(6))
connectNodes(nodes(5), nodes(6))
connectNodes(nodes(3), nodes(7))
connectNodes(nodes(0), nodes(8))
connectNodes(nodes(7), nodes(8))
nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}"))
nodes
}
private def createNode(index: Int): Future[LocalNode] = {
val configFolder = new File(s"build/node$index/")
assert(!configFolder.exists(), s"stale config exists in $configFolder")
Future(new LocalNode(index, configFolder))
}
private def connectNodes(first: LocalNode, second: LocalNode): Unit = {
first.connectionHandler.connect(s"localhost:${second.port}")
first.eventQueue.toStream.find(_._1 == LocalNode.EventType.ConnectionsChanged)
second.eventQueue.toStream.find(_._1 == LocalNode.EventType.ConnectionsChanged)
val firstAddress = first.crypto.localAddress
val secondAddress = second.crypto.localAddress
val firstConnections = first.connectionHandler.connections()
val secondConnections = second.connectionHandler.connections()
assert(firstConnections.contains(secondAddress),
s"${first.index} is not connected to ${second.index}")
assert(secondConnections.contains(firstAddress),
s"${second.index} is not connected to ${second.index}")
System.out.println(s"${first.index} and ${second.index} connected")
}
private def sendMessages(nodes: Seq[LocalNode]): Unit = {
sendMessage(nodes(0), nodes(2))
sendMessage(nodes(2), nodes(0))
sendMessage(nodes(4), nodes(3))
sendMessage(nodes(3), nodes(5))
sendMessage(nodes(4), nodes(6))
sendMessage(nodes(2), nodes(3))
sendMessage(nodes(0), nodes(3))
sendMessage(nodes(3), nodes(6))
sendMessage(nodes(3), nodes(2))
}
private def sendMessage(from: LocalNode, to: LocalNode): Unit = {
addKey(to.crypto, from.crypto)
addKey(from.crypto, to.crypto)
System.out.println(s"sendMessage(${from.index}, ${to.index})")
val text = s"${from.index} to ${to.index}"
from.connectionHandler.sendTo(to.crypto.localAddress, new Text(text))
val latch = new CountDownLatch(1)
Future {
val exists =
to.eventQueue.toStream.exists { event =>
if (event._1 != LocalNode.EventType.MessageReceived)
false
else {
event._2.get.body match {
case t: Text => t.text == text
case _ => false
}
}
}
assert(exists, s"message from ${from.index} did not arrive at ${to.index}")
latch.countDown()
}
assert(latch.await(1000, TimeUnit.MILLISECONDS))
}
private def addKey(addTo: Crypto, addFrom: Crypto): Unit = {
if (Try(addTo.getPublicKey(addFrom.localAddress)).isFailure)
addTo.addPublicKey(addFrom.localAddress, addFrom.getLocalPublicKey)
}
}

View file

@ -1 +1 @@
include ':android', ':core', ':server' include ':android', ':core', ':server', ':integration'