Merge branch 'relay-servers'

This commit is contained in:
Felix Ableitner 2016-06-24 13:35:42 +02:00
commit 5310c34218
19 changed files with 308 additions and 113 deletions

View file

@ -85,7 +85,7 @@ version, type and ID, followed by the length of the message.
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version | Protocol-Type | Hop Limit | Hop Count | | Version | Protocol-Type | Tokens | Hop Limit |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length | | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
@ -111,8 +111,8 @@ where such a packet came from MAY be closed.
Protocol-Type is one of those specified in section Protocol Messages, Protocol-Type is one of those specified in section Protocol Messages,
or 255 for Content Messages. or 255 for Content Messages.
Hop Limit SHOULD be set to `20` on message creation, and Tokens is the number of times this message should be copied to
MUST NOT be changed by a forwarding node. different relays.
Hop Count specifies the number of nodes a message may pass. When Hop Count specifies the number of nodes a message may pass. When
creating a package, it is initialized to 0. Whenever a node forwards creating a package, it is initialized to 0. Whenever a node forwards

View file

@ -12,6 +12,7 @@ import com.nutomic.ensichat.core.body.ConnectionInfo
import com.nutomic.ensichat.core.interfaces.{SettingsInterface, TransmissionInterface} import com.nutomic.ensichat.core.interfaces.{SettingsInterface, TransmissionInterface}
import com.nutomic.ensichat.core.{Address, ConnectionHandler, Message} import com.nutomic.ensichat.core.{Address, ConnectionHandler, Message}
import com.nutomic.ensichat.service.ChatService import com.nutomic.ensichat.service.ChatService
import org.joda.time.{DateTime, Duration}
import scala.collection.immutable.HashMap import scala.collection.immutable.HashMap
@ -169,12 +170,13 @@ class BluetoothInterface(context: Context, mainHandler: Handler,
/** /**
* Removes device from active connections. * Removes device from active connections.
*/ */
def onConnectionClosed(device: Device, socket: BluetoothSocket): Unit = { def onConnectionClosed(connectionOpened: DateTime, deviceId: Device.ID): Unit = {
val address = getAddressForDevice(device.id) val address = getAddressForDevice(deviceId)
devices -= device.id devices -= deviceId
connections -= device.id connections -= deviceId
addressDeviceMap = addressDeviceMap.filterNot(_._2 == device.id) addressDeviceMap = addressDeviceMap.filterNot(_._2 == deviceId)
connectionHandler.onConnectionClosed(address) val connectionDuration = new Duration(connectionOpened, DateTime.now)
connectionHandler.onConnectionClosed(address, connectionDuration)
} }
/** /**

View file

@ -9,6 +9,7 @@ import com.nutomic.ensichat.core.Message.ReadMessageException
import com.nutomic.ensichat.core.body.ConnectionInfo import com.nutomic.ensichat.core.body.ConnectionInfo
import com.nutomic.ensichat.core.header.MessageHeader import com.nutomic.ensichat.core.header.MessageHeader
import com.nutomic.ensichat.core.{Address, Crypto, Message} import com.nutomic.ensichat.core.{Address, Crypto, Message}
import org.joda.time.DateTime
/** /**
* Transfers data between connnected devices. * Transfers data between connnected devices.
@ -17,8 +18,11 @@ import com.nutomic.ensichat.core.{Address, Crypto, Message}
* @param socket An open socket to the given device. * @param socket An open socket to the given device.
* @param onReceive Called when a message was received from the other device. * @param onReceive Called when a message was received from the other device.
*/ */
class BluetoothTransferThread(context: Context, device: Device, socket: BluetoothSocket, handler: BluetoothInterface, class BluetoothTransferThread(context: Context, device: Device, socket: BluetoothSocket,
crypto: Crypto, onReceive: (Message, Device.ID) => Unit) extends Thread { handler: BluetoothInterface, crypto: Crypto,
onReceive: (Message, Device.ID) => Unit) extends Thread {
private val connectionOpened = DateTime.now
private val Tag = "TransferThread" private val Tag = "TransferThread"
@ -61,7 +65,7 @@ class BluetoothTransferThread(context: Context, device: Device, socket: Bluetoot
new IntentFilter(BluetoothDevice.ACTION_ACL_DISCONNECTED)) new IntentFilter(BluetoothDevice.ACTION_ACL_DISCONNECTED))
send(crypto.sign(new Message(new MessageHeader(ConnectionInfo.Type, send(crypto.sign(new Message(new MessageHeader(ConnectionInfo.Type,
Address.Null, Address.Null, 0), new ConnectionInfo(crypto.getLocalPublicKey)))) Address.Null, Address.Null, 0, 0), new ConnectionInfo(crypto.getLocalPublicKey))))
while (socket.isConnected) { while (socket.isConnected) {
try { try {
@ -102,7 +106,7 @@ class BluetoothTransferThread(context: Context, device: Device, socket: Bluetoot
} catch { } catch {
case e: IOException => Log.e(Tag, "Failed to close socket", e); case e: IOException => Log.e(Tag, "Failed to close socket", e);
} finally { } finally {
handler.onConnectionClosed(new Device(device.btDevice.get, false), null) handler.onConnectionClosed(connectionOpened, device.id)
} }
} }

View file

@ -1,15 +1,15 @@
package com.nutomic.ensichat.core package com.nutomic.ensichat.core
import java.security.InvalidKeyException import java.security.InvalidKeyException
import java.util.{TimerTask, Timer, Date} import java.util.Date
import com.nutomic.ensichat.core.body._ import com.nutomic.ensichat.core.body._
import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader} import com.nutomic.ensichat.core.header.{AbstractHeader, ContentHeader, MessageHeader}
import com.nutomic.ensichat.core.interfaces._ import com.nutomic.ensichat.core.interfaces._
import com.nutomic.ensichat.core.internet.InternetInterface import com.nutomic.ensichat.core.internet.InternetInterface
import com.nutomic.ensichat.core.util._ import com.nutomic.ensichat.core.util._
import com.typesafe.scalalogging.Logger import com.typesafe.scalalogging.Logger
import org.joda.time.{DateTime, Duration} import org.joda.time.Duration
import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future import scala.concurrent.Future
@ -27,8 +27,6 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
private val CheckMessageRetryInterval = Duration.standardMinutes(1)
private var transmissionInterfaces = Set[TransmissionInterface]() private var transmissionInterfaces = Set[TransmissionInterface]()
private lazy val seqNumGenerator = new SeqNumGenerator(settings) private lazy val seqNumGenerator = new SeqNumGenerator(settings)
@ -83,12 +81,13 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
FutureHelper { FutureHelper {
val messageId = settings.get("message_id", 0L) val messageId = settings.get("message_id", 0L)
val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(), val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(),
body.contentType, Some(messageId), Some(new Date())) body.contentType, Some(messageId), Some(new Date()), AbstractHeader.InitialForwardingTokens)
settings.put("message_id", messageId + 1) settings.put("message_id", messageId + 1)
val msg = new Message(header, body) val msg = new Message(header, body)
val encrypted = crypto.encryptAndSign(msg) val encrypted = crypto.encryptAndSign(msg)
router.forwardMessage(encrypted) router.forwardMessage(encrypted)
forwardMessageToRelays(encrypted)
onNewMessage(msg) onNewMessage(msg)
} }
} }
@ -98,7 +97,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
val seqNum = seqNumGenerator.next() val seqNum = seqNumGenerator.next()
val targetSeqNum = localRoutesInfo.getRoute(target).map(_.seqNum).getOrElse(-1) val targetSeqNum = localRoutesInfo.getRoute(target).map(_.seqNum).getOrElse(-1)
val body = new RouteRequest(target, seqNum, targetSeqNum, 0) val body = new RouteRequest(target, seqNum, targetSeqNum, 0)
val header = new MessageHeader(body.protocolType, crypto.localAddress, Address.Broadcast, seqNum) val header = new MessageHeader(body.protocolType, crypto.localAddress, Address.Broadcast, seqNum, 0)
val signed = crypto.sign(new Message(header, body)) val signed = crypto.sign(new Message(header, body))
logger.trace(s"sending new $signed") logger.trace(s"sending new $signed")
@ -108,7 +107,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
private def replyRoute(target: Address, replyTo: Address): Unit = { private def replyRoute(target: Address, replyTo: Address): Unit = {
val seqNum = seqNumGenerator.next() val seqNum = seqNumGenerator.next()
val body = new RouteReply(seqNum, 0) val body = new RouteReply(seqNum, 0)
val header = new MessageHeader(body.protocolType, crypto.localAddress, replyTo, seqNum) val header = new MessageHeader(body.protocolType, crypto.localAddress, replyTo, seqNum, 0)
val signed = crypto.sign(new Message(header, body)) val signed = crypto.sign(new Message(header, body))
logger.trace(s"sending new $signed") logger.trace(s"sending new $signed")
@ -118,7 +117,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
private def routeError(address: Address, packetSource: Option[Address]): Unit = { private def routeError(address: Address, packetSource: Option[Address]): Unit = {
val destination = packetSource.getOrElse(Address.Broadcast) val destination = packetSource.getOrElse(Address.Broadcast)
val header = new MessageHeader(RouteError.Type, crypto.localAddress, destination, val header = new MessageHeader(RouteError.Type, crypto.localAddress, destination,
seqNumGenerator.next()) seqNumGenerator.next(), 0)
val seqNum = localRoutesInfo.getRoute(address).map(_.seqNum).getOrElse(-1) val seqNum = localRoutesInfo.getRoute(address).map(_.seqNum).getOrElse(-1)
val body = new RouteError(address, seqNum) val body = new RouteError(address, seqNum)
@ -211,6 +210,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
if (msg.header.target != crypto.localAddress) { if (msg.header.target != crypto.localAddress) {
router.forwardMessage(msg) router.forwardMessage(msg)
forwardMessageToRelays(msg)
return return
} }
@ -234,6 +234,20 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
onNewMessage(plainMsg) onNewMessage(plainMsg)
} }
private def forwardMessageToRelays(message: Message): Unit = {
var tokens = message.header.tokens
val relays = database.pickLongestConnectionDevice(connections())
var index = 0
while (tokens > 1) {
val forwardTokens = tokens / 2
val headerCopy = message.header.asInstanceOf[ContentHeader].copy(tokens = forwardTokens)
router.forwardMessage(message.copy(header = headerCopy), relays.lift(index))
tokens -= forwardTokens
database.updateMessageForwardingTokens(message, tokens)
index += 1
}
}
/** /**
* Tries to send messages in [[MessageBuffer]] again, after we acquired a new route. * Tries to send messages in [[MessageBuffer]] again, after we acquired a new route.
*/ */
@ -244,11 +258,8 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
} }
private def noRouteFound(message: Message): Unit = { private def noRouteFound(message: Message): Unit = {
if (message.header.origin == crypto.localAddress) { messageBuffer.addMessage(message)
messageBuffer.addMessage(message) requestRoute(message.header.target)
requestRoute(message.header.target)
} else
routeError(message.header.target, Option(message.header.origin))
} }
/** /**
@ -317,13 +328,23 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
settings.get(SettingsInterface.KeyUserStatus, ""))) settings.get(SettingsInterface.KeyUserStatus, "")))
callbacks.onConnectionsChanged() callbacks.onConnectionsChanged()
resendMissingRouteMessages() resendMissingRouteMessages()
messageBuffer.getAllMessages
.filter(_.header.tokens > 1)
.foreach(forwardMessageToRelays)
true true
} }
def onConnectionClosed(address: Address): Unit = { /**
* Called by [[TransmissionInterface]] when a connection is closed.
*
* @param address The address of the connected device.
* @param duration The time that we were connected to the device.
*/
def onConnectionClosed(address: Address, duration: Duration): Unit = {
localRoutesInfo.connectionClosed(address) localRoutesInfo.connectionClosed(address)
.foreach(routeError(_, None)) .foreach(routeError(_, None))
callbacks.onConnectionsChanged() callbacks.onConnectionsChanged()
database.insertOrUpdateKnownDevice(address, duration)
} }
def connections(): Set[Address] = transmissionInterfaces.flatMap(_.getConnections) def connections(): Set[Address] = transmissionInterfaces.flatMap(_.getConnections)

View file

@ -7,6 +7,8 @@ import com.nutomic.ensichat.core.util.LocalRoutesInfo
object Router extends Comparator[Int] { object Router extends Comparator[Int] {
private val HopLimit = 20
/** /**
* Compares which sequence number is newer. * Compares which sequence number is newer.
* *
@ -50,7 +52,7 @@ private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message)
* true. * true.
*/ */
def forwardMessage(msg: Message, nextHopOption: Option[Address] = None): Unit = { def forwardMessage(msg: Message, nextHopOption: Option[Address] = None): Unit = {
if (msg.header.hopCount + 1 >= msg.header.hopLimit) if (msg.header.hopCount + 1 >= Router.HopLimit)
return return
val nextHop = nextHopOption.getOrElse(msg.header.target) val nextHop = nextHopOption.getOrElse(msg.header.target)
@ -79,10 +81,8 @@ private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message)
*/ */
private def incHopCount(msg: Message): Message = { private def incHopCount(msg: Message): Message = {
val updatedHeader = msg.header match { val updatedHeader = msg.header match {
case ch: ContentHeader => new ContentHeader(ch.origin, ch.target, ch.seqNum, ch.contentType, case ch: ContentHeader => ch.copy(hopCount = ch.hopCount + 1)
ch.messageId, ch.time, ch.hopCount + 1, ch.hopLimit) case mh: MessageHeader => mh.copy(hopCount = mh.hopCount + 1)
case mh: MessageHeader => new MessageHeader(mh.protocolType, mh.origin, mh.target, mh.seqNum,
mh.hopCount + 1, mh.hopLimit)
} }
new Message(updatedHeader, msg.crypto, msg.body) new Message(updatedHeader, msg.crypto, msg.body)
} }

View file

@ -8,7 +8,7 @@ import com.nutomic.ensichat.core.util.BufferUtils
object AbstractHeader { object AbstractHeader {
val DefaultHopLimit = 20 val InitialForwardingTokens = 3
val Version = 0 val Version = 0
@ -25,7 +25,7 @@ object AbstractHeader {
trait AbstractHeader { trait AbstractHeader {
def protocolType: Int def protocolType: Int
def hopLimit: Int def tokens: Int
def hopCount: Int def hopCount: Int
def origin: Address def origin: Address
def target: Address def target: Address
@ -41,7 +41,7 @@ trait AbstractHeader {
BufferUtils.putUnsignedByte(b, AbstractHeader.Version) BufferUtils.putUnsignedByte(b, AbstractHeader.Version)
BufferUtils.putUnsignedByte(b, protocolType) BufferUtils.putUnsignedByte(b, protocolType)
BufferUtils.putUnsignedByte(b, hopLimit) BufferUtils.putUnsignedByte(b, tokens)
BufferUtils.putUnsignedByte(b, hopCount) BufferUtils.putUnsignedByte(b, hopCount)
BufferUtils.putUnsignedInt(b, length + contentLength) BufferUtils.putUnsignedInt(b, length + contentLength)
@ -63,7 +63,7 @@ trait AbstractHeader {
override def equals(a: Any): Boolean = a match { override def equals(a: Any): Boolean = a match {
case o: AbstractHeader => case o: AbstractHeader =>
protocolType == o.protocolType && protocolType == o.protocolType &&
hopLimit == o.hopLimit && tokens == o.tokens &&
hopCount == o.hopCount && hopCount == o.hopCount &&
origin == o.origin && origin == o.origin &&
target == o.target && target == o.target &&

View file

@ -25,7 +25,7 @@ object ContentHeader {
val time = BufferUtils.getUnsignedInt(b) val time = BufferUtils.getUnsignedInt(b)
val ch = new ContentHeader(mh.origin, mh.target, mh.seqNum, contentType, Some(messageId), val ch = new ContentHeader(mh.origin, mh.target, mh.seqNum, contentType, Some(messageId),
Some(new Date(time * 1000)), mh.hopCount) Some(new Date(time * 1000)), mh.tokens, mh.hopCount)
val remaining = new Array[Byte](b.remaining()) val remaining = new Array[Byte](b.remaining())
b.get(remaining, 0, b.remaining()) b.get(remaining, 0, b.remaining())
@ -45,8 +45,8 @@ final case class ContentHeader(override val origin: Address,
contentType: Int, contentType: Int,
override val messageId: Some[Long], override val messageId: Some[Long],
override val time: Some[Date], override val time: Some[Date],
override val hopCount: Int = 0, override val tokens: Int,
override val hopLimit: Int = AbstractHeader.DefaultHopLimit) override val hopCount: Int = 0)
extends AbstractHeader { extends AbstractHeader {
override val protocolType = ContentHeader.ContentMessageType override val protocolType = ContentHeader.ContentMessageType

View file

@ -23,7 +23,7 @@ object MessageHeader {
if (version != AbstractHeader.Version) if (version != AbstractHeader.Version)
throw new ReadMessageException("Failed to parse message with unsupported version " + version) throw new ReadMessageException("Failed to parse message with unsupported version " + version)
val protocolType = BufferUtils.getUnsignedByte(b) val protocolType = BufferUtils.getUnsignedByte(b)
val hopLimit = BufferUtils.getUnsignedByte(b) val tokens = BufferUtils.getUnsignedByte(b)
val hopCount = BufferUtils.getUnsignedByte(b) val hopCount = BufferUtils.getUnsignedByte(b)
val length = BufferUtils.getUnsignedInt(b) val length = BufferUtils.getUnsignedInt(b)
@ -34,7 +34,7 @@ object MessageHeader {
val seqNum = BufferUtils.getUnsignedShort(b) val seqNum = BufferUtils.getUnsignedShort(b)
(new MessageHeader(protocolType, origin, target, seqNum, hopCount, hopLimit), length.toInt) (new MessageHeader(protocolType, origin, target, seqNum, tokens, hopCount), length.toInt)
} }
} }
@ -48,8 +48,8 @@ final case class MessageHeader(override val protocolType: Int,
override val origin: Address, override val origin: Address,
override val target: Address, override val target: Address,
override val seqNum: Int, override val seqNum: Int,
override val hopCount: Int = 0, override val tokens: Int,
override val hopLimit: Int = AbstractHeader.DefaultHopLimit) override val hopCount: Int = 0)
extends AbstractHeader { extends AbstractHeader {
def length: Int = MessageHeader.Length def length: Int = MessageHeader.Length

View file

@ -8,6 +8,7 @@ import com.nutomic.ensichat.core.body.ConnectionInfo
import com.nutomic.ensichat.core.header.MessageHeader import com.nutomic.ensichat.core.header.MessageHeader
import com.nutomic.ensichat.core.{Address, Crypto, Message} import com.nutomic.ensichat.core.{Address, Crypto, Message}
import com.typesafe.scalalogging.Logger import com.typesafe.scalalogging.Logger
import org.joda.time.DateTime
/** /**
* Encapsulates an active connection to another node. * Encapsulates an active connection to another node.
@ -17,6 +18,8 @@ private[core] class InternetConnectionThread(socket: Socket, crypto: Crypto,
onReceive: (Message, InternetConnectionThread) => Unit) onReceive: (Message, InternetConnectionThread) => Unit)
extends Thread { extends Thread {
val connectionOpened = DateTime.now
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
private val inStream: InputStream = private val inStream: InputStream =
@ -47,7 +50,7 @@ private[core] class InternetConnectionThread(socket: Socket, crypto: Crypto,
logger.info("Connection opened to " + socket.getInetAddress) logger.info("Connection opened to " + socket.getInetAddress)
send(crypto.sign(new Message(new MessageHeader(ConnectionInfo.Type, send(crypto.sign(new Message(new MessageHeader(ConnectionInfo.Type,
Address.Null, Address.Null, 0), new ConnectionInfo(crypto.getLocalPublicKey)))) Address.Null, Address.Null, 0, 0), new ConnectionInfo(crypto.getLocalPublicKey))))
try { try {
socket.setKeepAlive(true) socket.setKeepAlive(true)

View file

@ -7,6 +7,7 @@ import com.nutomic.ensichat.core.interfaces.{SettingsInterface, TransmissionInte
import com.nutomic.ensichat.core.util.FutureHelper import com.nutomic.ensichat.core.util.FutureHelper
import com.nutomic.ensichat.core.{Address, ConnectionHandler, Crypto, Message} import com.nutomic.ensichat.core.{Address, ConnectionHandler, Crypto, Message}
import com.typesafe.scalalogging.Logger import com.typesafe.scalalogging.Logger
import org.joda.time.{DateTime, Duration}
import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future import scala.concurrent.Future
@ -102,7 +103,8 @@ private[core] class InternetInterface(connectionHandler: ConnectionHandler, cryp
logger.trace("Connection closed to " + ad) logger.trace("Connection closed to " + ad)
connections -= connectionThread connections -= connectionThread
addressDeviceMap -= ad addressDeviceMap -= ad
connectionHandler.onConnectionClosed(ad) val connectionDuration = new Duration(connectionThread.connectionOpened, DateTime.now)
connectionHandler.onConnectionClosed(ad, connectionDuration)
} }
} }

View file

@ -1,13 +1,15 @@
package com.nutomic.ensichat.core.util package com.nutomic.ensichat.core.util
import java.io.File import java.io.File
import java.sql.DriverManager
import java.util.Date import java.util.Date
import com.nutomic.ensichat.core.body.Text import com.nutomic.ensichat.core.body.Text
import com.nutomic.ensichat.core.header.ContentHeader import com.nutomic.ensichat.core.header.ContentHeader
import com.nutomic.ensichat.core.interfaces.CallbackInterface import com.nutomic.ensichat.core.interfaces.{CallbackInterface, SettingsInterface}
import com.nutomic.ensichat.core.{Address, Message, User} import com.nutomic.ensichat.core.{Address, Message, User}
import com.typesafe.scalalogging.Logger import com.typesafe.scalalogging.Logger
import org.joda.time
import slick.driver.H2Driver.api._ import slick.driver.H2Driver.api._
import scala.concurrent.Await import scala.concurrent.Await
@ -19,10 +21,15 @@ import scala.concurrent.duration.Duration
* *
* @param path The database file. * @param path The database file.
*/ */
class Database(path: File, callbackInterface: CallbackInterface) { class Database(path: File, settings: SettingsInterface, callbackInterface: CallbackInterface) {
private val logger = Logger(this.getClass) private val logger = Logger(this.getClass)
private val DatabaseVersionKey = "database_version"
private val DatabaseVersion = 2
private val DatabasePath = "jdbc:h2:" + path.getAbsolutePath + ";DATABASE_TO_UPPER=false"
private class Messages(tag: Tag) extends Table[Message](tag, "MESSAGES") { private class Messages(tag: Tag) extends Table[Message](tag, "MESSAGES") {
def id = primaryKey("id", (origin, messageId)) def id = primaryKey("id", (origin, messageId))
def origin = column[String]("origin") def origin = column[String]("origin")
@ -30,20 +37,23 @@ class Database(path: File, callbackInterface: CallbackInterface) {
def messageId = column[Long]("message_id") def messageId = column[Long]("message_id")
def text = column[String]("text") def text = column[String]("text")
def date = column[Long]("date") def date = column[Long]("date")
def * = (origin, target, messageId, text, date).<> [Message, (String, String, Long, String, Long)]( { tuple => def tokens = column[Int]("tokens")
val header = new ContentHeader(new Address(tuple._1), def * = (origin, target, messageId, text, date, tokens) <> [Message, (String, String, Long, String, Long, Int)]( {
new Address(tuple._2), tuple =>
-1, val header = new ContentHeader(new Address(tuple._1),
Text.Type, new Address(tuple._2),
Some(tuple._3), -1,
Some(new Date(tuple._5))) Text.Type,
val body = new Text(tuple._4) Some(tuple._3),
new Message(header, body) Some(new Date(tuple._5)),
}, { message => tuple._6)
val body = new Text(tuple._4)
new Message(header, body)
}, message =>
Option((message.header.origin.toString(), message.header.target.toString(), Option((message.header.origin.toString(), message.header.target.toString(),
message.header.messageId.get, message.body.asInstanceOf[Text].text, message.header.messageId.get, message.body.asInstanceOf[Text].text,
message.header.time.get.getTime)) message.header.time.get.getTime, message.header.tokens))
}) )
} }
private val messages = TableQuery[Messages] private val messages = TableQuery[Messages]
@ -51,12 +61,21 @@ class Database(path: File, callbackInterface: CallbackInterface) {
def address = column[String]("address", O.PrimaryKey) def address = column[String]("address", O.PrimaryKey)
def name = column[String]("name") def name = column[String]("name")
def status = column[String]("status") def status = column[String]("status")
def wrappedAddress = address.<> [Address, String](new Address(_), a => Option(a.toString())) def wrappedAddress = address <> [Address, String](new Address(_), a => Option(a.toString))
def * = (wrappedAddress, name, status) <> (User.tupled, User.unapply) def * = (wrappedAddress, name, status) <> (User.tupled, User.unapply)
} }
private val contacts = TableQuery[Contacts] private val contacts = TableQuery[Contacts]
private val db = Database.forURL("jdbc:h2:" + path.getAbsolutePath, driver = "org.h2.Driver") private class KnownDevices(tag: Tag) extends Table[(Address, time.Duration)](tag, "KNOWN_DEVICES") {
def address = column[String]("address", O.PrimaryKey)
def totalConnectionSeconds = column[Long]("total_connection_seconds")
def * = (address, totalConnectionSeconds) <> [(Address, time.Duration), (String, Long)](
tuple => (new Address(tuple._1), time.Duration.standardSeconds(tuple._2)),
tuple => Option((tuple._1.toString, tuple._2.getStandardSeconds)))
}
private val knownDevices = TableQuery[KnownDevices]
private val db = Database.forURL(DatabasePath, driver = "org.h2.Driver")
// Create tables if database doesn't exist. // Create tables if database doesn't exist.
{ {
@ -64,7 +83,25 @@ class Database(path: File, callbackInterface: CallbackInterface) {
val dbFile = new File(path.getAbsolutePath + ".mv.db") val dbFile = new File(path.getAbsolutePath + ".mv.db")
if (!dbFile.exists()) { if (!dbFile.exists()) {
logger.info("Database does not exist, creating tables") logger.info("Database does not exist, creating tables")
Await.result(db.run((messages.schema ++ contacts.schema).create), Duration.Inf) val query = (messages.schema ++ contacts.schema ++ knownDevices.schema).create
Await.result(db.run(query), Duration.Inf)
settings.put(DatabaseVersionKey, DatabaseVersion)
}
}
// Apparently, slick doesn't support ALTER TABLE, so we have to write raw SQL for this...
{
val oldVersion = settings.get(DatabaseVersionKey, 0)
if (oldVersion != DatabaseVersion) {
logger.info(s"Upgrading database from version $oldVersion to $DatabaseVersion")
val connection = DriverManager.getConnection(DatabasePath)
if (oldVersion <= 2) {
connection.createStatement().executeUpdate("ALTER TABLE MESSAGES ADD COLUMN (tokens INT);")
connection.commit()
Await.result(db.run(knownDevices.schema.create), Duration.Inf)
}
connection.close()
settings.put(DatabaseVersionKey, DatabaseVersion)
} }
} }
@ -122,4 +159,30 @@ class Database(path: File, callbackInterface: CallbackInterface) {
callbackInterface.onContactsUpdated() callbackInterface.onContactsUpdated()
} }
def insertOrUpdateKnownDevice(address: Address, connectionTime: time.Duration): Unit = {
val query = knownDevices.insertOrUpdate((address, connectionTime))
Await.result(db.run(query), Duration.Inf)
}
/**
* Returns neighbors sorted by connection time, according to [[KnownDevices]].
*/
def pickLongestConnectionDevice(connections: Set[Address]): List[Address] = {
val map = Await.result(db.run(knownDevices.result), Duration.Inf).toMap
connections
.toList
.sortBy(map(_).getMillis)
.reverse
}
def updateMessageForwardingTokens(message: Message, tokens: Int): Unit = {
val query = messages.filter { c =>
c.origin === message.header.origin.toString &&
c.messageId === message.header.messageId
}
.map(_.tokens)
.update(tokens)
Await.result(db.run(query), Duration.Inf)
}
} }

View file

@ -82,6 +82,8 @@ class MessageBuffer(retryMessageSending: (Address) => Unit) {
ret.map(_.message) ret.map(_.message)
} }
def getAllMessages: Set[Message] = values.map(_.message)
private def handleTimeouts(): Unit = { private def handleTimeouts(): Unit = {
values = values.filter { e => values = values.filter { e =>
e.retryCount < MaxRetryCount e.retryCount < MaxRetryCount

View file

@ -40,7 +40,7 @@ class MessageTest extends TestCase {
} }
def testSerializeSigned(): Unit = { def testSerializeSigned(): Unit = {
val header = new MessageHeader(ConnectionInfo.Type, AddressTest.a4, AddressTest.a2, 0) val header = new MessageHeader(ConnectionInfo.Type, AddressTest.a4, AddressTest.a2, 0, 3)
val m = new Message(header, ConnectionInfoTest.generateCi()) val m = new Message(header, ConnectionInfoTest.generateCi())
val signed = crypto.sign(m) val signed = crypto.sign(m)

View file

@ -46,7 +46,7 @@ class RouterTest extends TestCase {
assertEquals(msg.header.seqNum, m.header.seqNum) assertEquals(msg.header.seqNum, m.header.seqNum)
assertEquals(msg.header.protocolType, m.header.protocolType) assertEquals(msg.header.protocolType, m.header.protocolType)
assertEquals(msg.header.hopCount + 1, m.header.hopCount) assertEquals(msg.header.hopCount + 1, m.header.hopCount)
assertEquals(msg.header.hopLimit, m.header.hopLimit) assertEquals(msg.header.tokens, m.header.tokens)
assertEquals(msg.body, m.body) assertEquals(msg.body, m.body)
assertEquals(msg.crypto, m.crypto) assertEquals(msg.crypto, m.crypto)
}, _ => ()) }, _ => ())
@ -93,14 +93,14 @@ class RouterTest extends TestCase {
def testHopLimit(): Unit = Range(19, 22).foreach { i => def testHopLimit(): Unit = Range(19, 22).foreach { i =>
val msg = new Message( val msg = new Message(
new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), i), new Text("")) new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), 3, i), new Text(""))
val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => fail(), _ => ()) val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => fail(), _ => ())
router.forwardMessage(msg) router.forwardMessage(msg)
} }
private def generateMessage(sender: Address, receiver: Address, seqNum: Int): Message = { private def generateMessage(sender: Address, receiver: Address, seqNum: Int): Message = {
val header = new ContentHeader(sender, receiver, seqNum, UserInfo.Type, Some(5), val header = new ContentHeader(sender, receiver, seqNum, UserInfo.Type, Some(5),
Some(new GregorianCalendar(2014, 6, 10).getTime)) Some(new GregorianCalendar(2014, 6, 10).getTime), 3)
new Message(header, new UserInfo("", "")) new Message(header, new UserInfo("", ""))
} }

View file

@ -11,9 +11,9 @@ object MessageHeaderTest {
0) 0)
val h2 = new MessageHeader(ContentHeader.ContentMessageType, Address.Null, Address.Broadcast, val h2 = new MessageHeader(ContentHeader.ContentMessageType, Address.Null, Address.Broadcast,
ContentHeader.SeqNumRange.last, 0xff) ContentHeader.SeqNumRange.last, 0xff, 3)
val h3 = new MessageHeader(ContentHeader.ContentMessageType, Address.Broadcast, Address.Null, 0) val h3 = new MessageHeader(ContentHeader.ContentMessageType, Address.Broadcast, Address.Null, 0, 3)
val headers = Set(h1, h2, h3) val headers = Set(h1, h2, h3)

View file

@ -13,7 +13,7 @@ class RouteMessageInfoTest extends TestCase {
* Test case in which we have an entry with the same type, origin and target. * Test case in which we have an entry with the same type, origin and target.
*/ */
def testSameMessage(): Unit = { def testSameMessage(): Unit = {
val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0)
val msg = new Message(header, new RouteRequest(AddressTest.a3, 2, 3, 1)) val msg = new Message(header, new RouteRequest(AddressTest.a3, 2, 3, 1))
val rmi = new RouteMessageInfo() val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg)) assertFalse(rmi.isMessageRedundant(msg))
@ -24,12 +24,12 @@ class RouteMessageInfoTest extends TestCase {
* Forward a message with a seqnum that is older than the latest. * Forward a message with a seqnum that is older than the latest.
*/ */
def testSeqNumOlder(): Unit = { def testSeqNumOlder(): Unit = {
val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0)
val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 0, 0, 0)) val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 0, 0, 0))
val rmi = new RouteMessageInfo() val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg1)) assertFalse(rmi.isMessageRedundant(msg1))
val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 3) val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 3, 0)
val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 2, 0, 0)) val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 2, 0, 0))
assertTrue(rmi.isMessageRedundant(msg2)) assertTrue(rmi.isMessageRedundant(msg2))
} }
@ -38,12 +38,12 @@ class RouteMessageInfoTest extends TestCase {
* Announce a route with a metric that is worse than the existing one. * Announce a route with a metric that is worse than the existing one.
*/ */
def testMetricWorse(): Unit = { def testMetricWorse(): Unit = {
val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0)
val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 1, 0, 2)) val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 1, 0, 2))
val rmi = new RouteMessageInfo() val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg1)) assertFalse(rmi.isMessageRedundant(msg1))
val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2) val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2, 0)
val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 1, 0, 4)) val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 1, 0, 4))
assertTrue(rmi.isMessageRedundant(msg2)) assertTrue(rmi.isMessageRedundant(msg2))
} }
@ -52,12 +52,12 @@ class RouteMessageInfoTest extends TestCase {
* Announce route with a better metric. * Announce route with a better metric.
*/ */
def testMetricBetter(): Unit = { def testMetricBetter(): Unit = {
val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0)
val msg1 = new Message(header1, new RouteReply(0, 4)) val msg1 = new Message(header1, new RouteReply(0, 4))
val rmi = new RouteMessageInfo() val rmi = new RouteMessageInfo()
assertFalse(rmi.isMessageRedundant(msg1)) assertFalse(rmi.isMessageRedundant(msg1))
val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2) val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2, 0)
val msg2 = new Message(header2, new RouteReply(0, 2)) val msg2 = new Message(header2, new RouteReply(0, 2))
assertFalse(rmi.isMessageRedundant(msg2)) assertFalse(rmi.isMessageRedundant(msg2))
} }
@ -68,7 +68,7 @@ class RouteMessageInfoTest extends TestCase {
def testTimeout(): Unit = { def testTimeout(): Unit = {
val rmi = new RouteMessageInfo() val rmi = new RouteMessageInfo()
DateTimeUtils.setCurrentMillisFixed(DateTime.now.getMillis) DateTimeUtils.setCurrentMillisFixed(DateTime.now.getMillis)
val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0)
val msg = new Message(header, new RouteRequest(AddressTest.a3, 0, 0, 0)) val msg = new Message(header, new RouteRequest(AddressTest.a3, 0, 0, 0))
assertFalse(rmi.isMessageRedundant(msg)) assertFalse(rmi.isMessageRedundant(msg))

View file

@ -45,12 +45,12 @@ class LocalNode(val index: Int, configFolder: File) extends CallbackInterface {
private val databaseFile = new File(configFolder, "database") private val databaseFile = new File(configFolder, "database")
private val keyFolder = new File(configFolder, "keys") private val keyFolder = new File(configFolder, "keys")
private val database = new Database(databaseFile, this)
private val settings = new SettingsInterface { private val settings = new SettingsInterface {
private var values = Map[String, Any]() private var values = Map[String, Any]()
override def get[T](key: String, default: T): T = values.get(key).map(_.asInstanceOf[T]).getOrElse(default) override def get[T](key: String, default: T): T = values.get(key).map(_.asInstanceOf[T]).getOrElse(default)
override def put[T](key: String, value: T): Unit = values += (key -> value.asInstanceOf[Any]) override def put[T](key: String, value: T): Unit = values += (key -> value.asInstanceOf[Any])
} }
private val database = new Database(databaseFile, settings, this)
val crypto = new Crypto(settings, keyFolder) val crypto = new Crypto(settings, keyFolder)
val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 0, port) val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 0, port)

View file

@ -20,40 +20,141 @@ import scalax.file.Path
*/ */
object Main extends App { object Main extends App {
val nodes = createMesh() // NOTE: These tests are somewhat fragile, and might fail randomly. It helps to run only
System.out.println("\n\nAll nodes connected!\n\n") // one of the following functions at a time.
testNeighborSending()
testMeshMessageSending()
testIndirectRelay()
testNeighborRelay()
testMessageDeliveryOnConnect()
testSendDelayed()
testRouteChange()
sendMessages(nodes) private def testNeighborSending(): Unit = {
System.out.println("\n\nMessages sent!\n\n") val node1 = Await.result(createNode(1), Duration.Inf)
val node2 = Await.result(createNode(2), Duration.Inf)
// Stop node 1, forcing route errors and messages to use the (longer) path via nodes 7 and 8. connectNodes(node1, node2)
nodes(1).connectionHandler.stop() sendMessage(node1, node2)
System.out.println("Node 1 stopped")
sendMessages(nodes)
System.out.println("\n\nMessages after route change sent!\n\n")
// Create new node 9, send message from node 0 to its address, before actually connecting it. Set(node1, node2).foreach(_.stop())
// The message is automatically delivered when node 9 connects as neighbor.
val node9 = Await.result(createNode(9), Duration.Inf)
val timer = new Timer()
timer.schedule(new TimerTask {
override def run(): Unit = {
connectNodes(nodes(0), node9)
}
}, Duration(10, TimeUnit.SECONDS).toMillis)
sendMessage(nodes(0), node9, 30)
// Create new node 10, send message from node 7 to its address, before connecting it to the mesh. System.out.println("Test neighbor sending successful!")
// The message is delivered after node 7 starts a route discovery triggered by the message buffer. }
val node10 = Await.result(createNode(10), Duration.Inf)
timer.schedule(new TimerTask { private def testNeighborRelay(): Unit = {
override def run(): Unit = { val nodes = createNodes(3)
connectNodes(nodes(0), node10)
timer.cancel() connectNodes(nodes(0), nodes(1))
}
}, Duration(5, TimeUnit.SECONDS).toMillis) val timer = new Timer()
sendMessage(nodes(7), node10, 30) timer.schedule(new TimerTask {
System.out.println("\n\nMessages after delay sent!\n\n") override def run(): Unit = {
nodes(0).stop()
connectNodes(nodes(1), nodes(2))
}
}, Duration(10, TimeUnit.SECONDS).toMillis)
sendMessage(nodes(0), nodes(2), 30)
timer.cancel()
nodes.foreach(_.stop())
System.out.println("Test neighbor relay successful!")
}
private def testIndirectRelay(): Unit = {
val nodes = createNodes(5)
connectNodes(nodes(0), nodes(1))
connectNodes(nodes(1), nodes(2))
connectNodes(nodes(2), nodes(3))
val timer = new Timer()
timer.schedule(new TimerTask {
override def run(): Unit = {
nodes(0).stop()
connectNodes(nodes(3), nodes(4))
}
}, Duration(10, TimeUnit.SECONDS).toMillis)
sendMessage(nodes(0), nodes(4), 30)
timer.cancel()
nodes.foreach(_.stop())
System.out.println("Test indirect sending successful!")
}
private def testMeshMessageSending(): Unit = {
val nodes = createMesh()
sendMessages(nodes)
nodes.foreach(_.stop())
System.out.println("Test mesh message sending successful!")
}
/**
* Stop node 1, forcing route errors and messages to use the (longer) path via nodes 7 and 8.
*/
private def testRouteChange() {
val nodes = createMesh()
nodes(1).connectionHandler.stop()
sendMessages(nodes)
nodes.foreach(_.stop())
System.out.println("Test route change successful!")
}
/**
* Create new node 9, send message from node 0 to its address, before actually connecting it.
* The message is automatically delivered when node 9 connects as neighbor.
*/
private def testMessageDeliveryOnConnect() {
val nodes = createMesh()
val node9 = Await.result(createNode(9), Duration.Inf)
val timer = new Timer()
timer.schedule(new TimerTask {
override def run(): Unit = {
connectNodes(nodes(0), node9)
timer.cancel()
}
}, Duration(10, TimeUnit.SECONDS).toMillis)
sendMessage(nodes(0), node9, 30)
(nodes :+ node9).foreach(_.stop())
System.out.println("Test message delivery on connect successful!")
}
/**
* Create new node 10, send message from node 7 to its address, before connecting it to the mesh.
* The message is delivered after node 7 starts a route discovery triggered by the message buffer.
*/
private def testSendDelayed(): Unit = {
val nodes = createMesh()
val timer = new Timer()
val node10 = Await.result(createNode(10), Duration.Inf)
timer.schedule(new TimerTask {
override def run(): Unit = {
connectNodes(nodes(0), node10)
timer.cancel()
}
}, Duration(5, TimeUnit.SECONDS).toMillis)
sendMessage(nodes(7), node10, 30)
(nodes :+ node10).foreach(_.stop())
System.out.println("Test send delayed successful!")
}
private def createNodes(count: Int): Seq[LocalNode] = {
val nodes = Await.result(Future.sequence((0 until count).map(createNode)), Duration.Inf)
nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}"))
nodes
}
/** /**
* Creates a new mesh with a predefined layout. * Creates a new mesh with a predefined layout.
@ -68,8 +169,7 @@ object Main extends App {
* @return List of [[LocalNode]]s, ordered from 0 to 8. * @return List of [[LocalNode]]s, ordered from 0 to 8.
*/ */
private def createMesh(): Seq[LocalNode] = { private def createMesh(): Seq[LocalNode] = {
val nodes = Await.result(Future.sequence(0.to(8).map(createNode)), Duration.Inf) val nodes = createNodes(9)
sys.addShutdownHook(nodes.foreach(_.stop()))
connectNodes(nodes(0), nodes(1)) connectNodes(nodes(0), nodes(1))
connectNodes(nodes(0), nodes(2)) connectNodes(nodes(0), nodes(2))
@ -82,7 +182,6 @@ object Main extends App {
connectNodes(nodes(3), nodes(7)) connectNodes(nodes(3), nodes(7))
connectNodes(nodes(0), nodes(8)) connectNodes(nodes(0), nodes(8))
connectNodes(nodes(7), nodes(8)) connectNodes(nodes(7), nodes(8))
nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}"))
nodes nodes
} }
@ -119,7 +218,6 @@ object Main extends App {
sendMessage(nodes(3), nodes(5)) sendMessage(nodes(3), nodes(5))
sendMessage(nodes(4), nodes(6)) sendMessage(nodes(4), nodes(6))
sendMessage(nodes(2), nodes(3)) sendMessage(nodes(2), nodes(3))
sendMessage(nodes(0), nodes(3))
sendMessage(nodes(3), nodes(6)) sendMessage(nodes(3), nodes(6))
sendMessage(nodes(3), nodes(2)) sendMessage(nodes(3), nodes(2))
} }

View file

@ -22,7 +22,7 @@ object Main extends App with CallbackInterface {
private lazy val settings = new Settings(ConfigFile) private lazy val settings = new Settings(ConfigFile)
private lazy val crypto = new Crypto(settings, KeyFolder) private lazy val crypto = new Crypto(settings, KeyFolder)
private lazy val database = new Database(DatabaseFile, this) private lazy val database = new Database(DatabaseFile, settings, this)
private lazy val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 7) private lazy val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 7)
init() init()