diff --git a/PROTOCOL.md b/PROTOCOL.md index 209754c..1923f43 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -85,7 +85,7 @@ version, type and ID, followed by the length of the message. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Version | Protocol-Type | Hop Limit | Hop Count | + | Version | Protocol-Type | Tokens | Hop Limit | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -111,8 +111,8 @@ where such a packet came from MAY be closed. Protocol-Type is one of those specified in section Protocol Messages, or 255 for Content Messages. -Hop Limit SHOULD be set to `20` on message creation, and -MUST NOT be changed by a forwarding node. +Tokens is the number of times this message should be copied to +different relays. Hop Count specifies the number of nodes a message may pass. When creating a package, it is initialized to 0. Whenever a node forwards diff --git a/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala b/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala index 42d37e1..aa7ad4d 100644 --- a/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala +++ b/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothInterface.scala @@ -12,6 +12,7 @@ import com.nutomic.ensichat.core.body.ConnectionInfo import com.nutomic.ensichat.core.interfaces.{SettingsInterface, TransmissionInterface} import com.nutomic.ensichat.core.{Address, ConnectionHandler, Message} import com.nutomic.ensichat.service.ChatService +import org.joda.time.{DateTime, Duration} import scala.collection.immutable.HashMap @@ -169,12 +170,13 @@ class BluetoothInterface(context: Context, mainHandler: Handler, /** * Removes device from active connections. */ - def onConnectionClosed(device: Device, socket: BluetoothSocket): Unit = { - val address = getAddressForDevice(device.id) - devices -= device.id - connections -= device.id - addressDeviceMap = addressDeviceMap.filterNot(_._2 == device.id) - connectionHandler.onConnectionClosed(address) + def onConnectionClosed(connectionOpened: DateTime, deviceId: Device.ID): Unit = { + val address = getAddressForDevice(deviceId) + devices -= deviceId + connections -= deviceId + addressDeviceMap = addressDeviceMap.filterNot(_._2 == deviceId) + val connectionDuration = new Duration(connectionOpened, DateTime.now) + connectionHandler.onConnectionClosed(address, connectionDuration) } /** diff --git a/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothTransferThread.scala b/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothTransferThread.scala index 240a2fd..2f389b6 100644 --- a/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothTransferThread.scala +++ b/android/src/main/scala/com/nutomic/ensichat/bluetooth/BluetoothTransferThread.scala @@ -9,6 +9,7 @@ import com.nutomic.ensichat.core.Message.ReadMessageException import com.nutomic.ensichat.core.body.ConnectionInfo import com.nutomic.ensichat.core.header.MessageHeader import com.nutomic.ensichat.core.{Address, Crypto, Message} +import org.joda.time.DateTime /** * Transfers data between connnected devices. @@ -17,8 +18,11 @@ import com.nutomic.ensichat.core.{Address, Crypto, Message} * @param socket An open socket to the given device. * @param onReceive Called when a message was received from the other device. */ -class BluetoothTransferThread(context: Context, device: Device, socket: BluetoothSocket, handler: BluetoothInterface, - crypto: Crypto, onReceive: (Message, Device.ID) => Unit) extends Thread { +class BluetoothTransferThread(context: Context, device: Device, socket: BluetoothSocket, + handler: BluetoothInterface, crypto: Crypto, + onReceive: (Message, Device.ID) => Unit) extends Thread { + + private val connectionOpened = DateTime.now private val Tag = "TransferThread" @@ -61,7 +65,7 @@ class BluetoothTransferThread(context: Context, device: Device, socket: Bluetoot new IntentFilter(BluetoothDevice.ACTION_ACL_DISCONNECTED)) send(crypto.sign(new Message(new MessageHeader(ConnectionInfo.Type, - Address.Null, Address.Null, 0), new ConnectionInfo(crypto.getLocalPublicKey)))) + Address.Null, Address.Null, 0, 0), new ConnectionInfo(crypto.getLocalPublicKey)))) while (socket.isConnected) { try { @@ -102,7 +106,7 @@ class BluetoothTransferThread(context: Context, device: Device, socket: Bluetoot } catch { case e: IOException => Log.e(Tag, "Failed to close socket", e); } finally { - handler.onConnectionClosed(new Device(device.btDevice.get, false), null) + handler.onConnectionClosed(connectionOpened, device.id) } } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala b/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala index 408ca89..3ff4bde 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/ConnectionHandler.scala @@ -1,15 +1,15 @@ package com.nutomic.ensichat.core import java.security.InvalidKeyException -import java.util.{TimerTask, Timer, Date} +import java.util.Date import com.nutomic.ensichat.core.body._ -import com.nutomic.ensichat.core.header.{ContentHeader, MessageHeader} +import com.nutomic.ensichat.core.header.{AbstractHeader, ContentHeader, MessageHeader} import com.nutomic.ensichat.core.interfaces._ import com.nutomic.ensichat.core.internet.InternetInterface import com.nutomic.ensichat.core.util._ import com.typesafe.scalalogging.Logger -import org.joda.time.{DateTime, Duration} +import org.joda.time.Duration import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -27,8 +27,6 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, private val logger = Logger(this.getClass) - private val CheckMessageRetryInterval = Duration.standardMinutes(1) - private var transmissionInterfaces = Set[TransmissionInterface]() private lazy val seqNumGenerator = new SeqNumGenerator(settings) @@ -83,12 +81,13 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, FutureHelper { val messageId = settings.get("message_id", 0L) val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(), - body.contentType, Some(messageId), Some(new Date())) + body.contentType, Some(messageId), Some(new Date()), AbstractHeader.InitialForwardingTokens) settings.put("message_id", messageId + 1) val msg = new Message(header, body) val encrypted = crypto.encryptAndSign(msg) router.forwardMessage(encrypted) + forwardMessageToRelays(encrypted) onNewMessage(msg) } } @@ -98,7 +97,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, val seqNum = seqNumGenerator.next() val targetSeqNum = localRoutesInfo.getRoute(target).map(_.seqNum).getOrElse(-1) val body = new RouteRequest(target, seqNum, targetSeqNum, 0) - val header = new MessageHeader(body.protocolType, crypto.localAddress, Address.Broadcast, seqNum) + val header = new MessageHeader(body.protocolType, crypto.localAddress, Address.Broadcast, seqNum, 0) val signed = crypto.sign(new Message(header, body)) logger.trace(s"sending new $signed") @@ -108,7 +107,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, private def replyRoute(target: Address, replyTo: Address): Unit = { val seqNum = seqNumGenerator.next() val body = new RouteReply(seqNum, 0) - val header = new MessageHeader(body.protocolType, crypto.localAddress, replyTo, seqNum) + val header = new MessageHeader(body.protocolType, crypto.localAddress, replyTo, seqNum, 0) val signed = crypto.sign(new Message(header, body)) logger.trace(s"sending new $signed") @@ -118,7 +117,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, private def routeError(address: Address, packetSource: Option[Address]): Unit = { val destination = packetSource.getOrElse(Address.Broadcast) val header = new MessageHeader(RouteError.Type, crypto.localAddress, destination, - seqNumGenerator.next()) + seqNumGenerator.next(), 0) val seqNum = localRoutesInfo.getRoute(address).map(_.seqNum).getOrElse(-1) val body = new RouteError(address, seqNum) @@ -211,6 +210,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, if (msg.header.target != crypto.localAddress) { router.forwardMessage(msg) + forwardMessageToRelays(msg) return } @@ -234,6 +234,20 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, onNewMessage(plainMsg) } + private def forwardMessageToRelays(message: Message): Unit = { + var tokens = message.header.tokens + val relays = database.pickLongestConnectionDevice(connections()) + var index = 0 + while (tokens > 1) { + val forwardTokens = tokens / 2 + val headerCopy = message.header.asInstanceOf[ContentHeader].copy(tokens = forwardTokens) + router.forwardMessage(message.copy(header = headerCopy), relays.lift(index)) + tokens -= forwardTokens + database.updateMessageForwardingTokens(message, tokens) + index += 1 + } + } + /** * Tries to send messages in [[MessageBuffer]] again, after we acquired a new route. */ @@ -244,11 +258,8 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, } private def noRouteFound(message: Message): Unit = { - if (message.header.origin == crypto.localAddress) { - messageBuffer.addMessage(message) - requestRoute(message.header.target) - } else - routeError(message.header.target, Option(message.header.origin)) + messageBuffer.addMessage(message) + requestRoute(message.header.target) } /** @@ -317,13 +328,23 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database, settings.get(SettingsInterface.KeyUserStatus, ""))) callbacks.onConnectionsChanged() resendMissingRouteMessages() + messageBuffer.getAllMessages + .filter(_.header.tokens > 1) + .foreach(forwardMessageToRelays) true } - def onConnectionClosed(address: Address): Unit = { + /** + * Called by [[TransmissionInterface]] when a connection is closed. + * + * @param address The address of the connected device. + * @param duration The time that we were connected to the device. + */ + def onConnectionClosed(address: Address, duration: Duration): Unit = { localRoutesInfo.connectionClosed(address) .foreach(routeError(_, None)) callbacks.onConnectionsChanged() + database.insertOrUpdateKnownDevice(address, duration) } def connections(): Set[Address] = transmissionInterfaces.flatMap(_.getConnections) diff --git a/core/src/main/scala/com/nutomic/ensichat/core/Router.scala b/core/src/main/scala/com/nutomic/ensichat/core/Router.scala index 02159b8..fd0de6c 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/Router.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/Router.scala @@ -7,6 +7,8 @@ import com.nutomic.ensichat.core.util.LocalRoutesInfo object Router extends Comparator[Int] { + private val HopLimit = 20 + /** * Compares which sequence number is newer. * @@ -50,7 +52,7 @@ private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message) * true. */ def forwardMessage(msg: Message, nextHopOption: Option[Address] = None): Unit = { - if (msg.header.hopCount + 1 >= msg.header.hopLimit) + if (msg.header.hopCount + 1 >= Router.HopLimit) return val nextHop = nextHopOption.getOrElse(msg.header.target) @@ -79,10 +81,8 @@ private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message) */ private def incHopCount(msg: Message): Message = { val updatedHeader = msg.header match { - case ch: ContentHeader => new ContentHeader(ch.origin, ch.target, ch.seqNum, ch.contentType, - ch.messageId, ch.time, ch.hopCount + 1, ch.hopLimit) - case mh: MessageHeader => new MessageHeader(mh.protocolType, mh.origin, mh.target, mh.seqNum, - mh.hopCount + 1, mh.hopLimit) + case ch: ContentHeader => ch.copy(hopCount = ch.hopCount + 1) + case mh: MessageHeader => mh.copy(hopCount = mh.hopCount + 1) } new Message(updatedHeader, msg.crypto, msg.body) } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/header/AbstractHeader.scala b/core/src/main/scala/com/nutomic/ensichat/core/header/AbstractHeader.scala index a2a4589..97e1073 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/header/AbstractHeader.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/header/AbstractHeader.scala @@ -8,7 +8,7 @@ import com.nutomic.ensichat.core.util.BufferUtils object AbstractHeader { - val DefaultHopLimit = 20 + val InitialForwardingTokens = 3 val Version = 0 @@ -25,7 +25,7 @@ object AbstractHeader { trait AbstractHeader { def protocolType: Int - def hopLimit: Int + def tokens: Int def hopCount: Int def origin: Address def target: Address @@ -41,7 +41,7 @@ trait AbstractHeader { BufferUtils.putUnsignedByte(b, AbstractHeader.Version) BufferUtils.putUnsignedByte(b, protocolType) - BufferUtils.putUnsignedByte(b, hopLimit) + BufferUtils.putUnsignedByte(b, tokens) BufferUtils.putUnsignedByte(b, hopCount) BufferUtils.putUnsignedInt(b, length + contentLength) @@ -63,7 +63,7 @@ trait AbstractHeader { override def equals(a: Any): Boolean = a match { case o: AbstractHeader => protocolType == o.protocolType && - hopLimit == o.hopLimit && + tokens == o.tokens && hopCount == o.hopCount && origin == o.origin && target == o.target && diff --git a/core/src/main/scala/com/nutomic/ensichat/core/header/ContentHeader.scala b/core/src/main/scala/com/nutomic/ensichat/core/header/ContentHeader.scala index dd9c7a7..f58ce36 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/header/ContentHeader.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/header/ContentHeader.scala @@ -25,7 +25,7 @@ object ContentHeader { val time = BufferUtils.getUnsignedInt(b) val ch = new ContentHeader(mh.origin, mh.target, mh.seqNum, contentType, Some(messageId), - Some(new Date(time * 1000)), mh.hopCount) + Some(new Date(time * 1000)), mh.tokens, mh.hopCount) val remaining = new Array[Byte](b.remaining()) b.get(remaining, 0, b.remaining()) @@ -45,8 +45,8 @@ final case class ContentHeader(override val origin: Address, contentType: Int, override val messageId: Some[Long], override val time: Some[Date], - override val hopCount: Int = 0, - override val hopLimit: Int = AbstractHeader.DefaultHopLimit) + override val tokens: Int, + override val hopCount: Int = 0) extends AbstractHeader { override val protocolType = ContentHeader.ContentMessageType diff --git a/core/src/main/scala/com/nutomic/ensichat/core/header/MessageHeader.scala b/core/src/main/scala/com/nutomic/ensichat/core/header/MessageHeader.scala index 9b77509..2f859a4 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/header/MessageHeader.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/header/MessageHeader.scala @@ -23,7 +23,7 @@ object MessageHeader { if (version != AbstractHeader.Version) throw new ReadMessageException("Failed to parse message with unsupported version " + version) val protocolType = BufferUtils.getUnsignedByte(b) - val hopLimit = BufferUtils.getUnsignedByte(b) + val tokens = BufferUtils.getUnsignedByte(b) val hopCount = BufferUtils.getUnsignedByte(b) val length = BufferUtils.getUnsignedInt(b) @@ -34,7 +34,7 @@ object MessageHeader { val seqNum = BufferUtils.getUnsignedShort(b) - (new MessageHeader(protocolType, origin, target, seqNum, hopCount, hopLimit), length.toInt) + (new MessageHeader(protocolType, origin, target, seqNum, tokens, hopCount), length.toInt) } } @@ -48,8 +48,8 @@ final case class MessageHeader(override val protocolType: Int, override val origin: Address, override val target: Address, override val seqNum: Int, - override val hopCount: Int = 0, - override val hopLimit: Int = AbstractHeader.DefaultHopLimit) + override val tokens: Int, + override val hopCount: Int = 0) extends AbstractHeader { def length: Int = MessageHeader.Length diff --git a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala index 4235964..b559128 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetConnectionThread.scala @@ -8,6 +8,7 @@ import com.nutomic.ensichat.core.body.ConnectionInfo import com.nutomic.ensichat.core.header.MessageHeader import com.nutomic.ensichat.core.{Address, Crypto, Message} import com.typesafe.scalalogging.Logger +import org.joda.time.DateTime /** * Encapsulates an active connection to another node. @@ -17,6 +18,8 @@ private[core] class InternetConnectionThread(socket: Socket, crypto: Crypto, onReceive: (Message, InternetConnectionThread) => Unit) extends Thread { + val connectionOpened = DateTime.now + private val logger = Logger(this.getClass) private val inStream: InputStream = @@ -47,7 +50,7 @@ private[core] class InternetConnectionThread(socket: Socket, crypto: Crypto, logger.info("Connection opened to " + socket.getInetAddress) send(crypto.sign(new Message(new MessageHeader(ConnectionInfo.Type, - Address.Null, Address.Null, 0), new ConnectionInfo(crypto.getLocalPublicKey)))) + Address.Null, Address.Null, 0, 0), new ConnectionInfo(crypto.getLocalPublicKey)))) try { socket.setKeepAlive(true) diff --git a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala index f68905d..d3f9ce1 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/internet/InternetInterface.scala @@ -7,6 +7,7 @@ import com.nutomic.ensichat.core.interfaces.{SettingsInterface, TransmissionInte import com.nutomic.ensichat.core.util.FutureHelper import com.nutomic.ensichat.core.{Address, ConnectionHandler, Crypto, Message} import com.typesafe.scalalogging.Logger +import org.joda.time.{DateTime, Duration} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -102,7 +103,8 @@ private[core] class InternetInterface(connectionHandler: ConnectionHandler, cryp logger.trace("Connection closed to " + ad) connections -= connectionThread addressDeviceMap -= ad - connectionHandler.onConnectionClosed(ad) + val connectionDuration = new Duration(connectionThread.connectionOpened, DateTime.now) + connectionHandler.onConnectionClosed(ad, connectionDuration) } } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala b/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala index ec8d464..8cd4a3f 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/util/Database.scala @@ -1,13 +1,15 @@ package com.nutomic.ensichat.core.util import java.io.File +import java.sql.DriverManager import java.util.Date import com.nutomic.ensichat.core.body.Text import com.nutomic.ensichat.core.header.ContentHeader -import com.nutomic.ensichat.core.interfaces.CallbackInterface +import com.nutomic.ensichat.core.interfaces.{CallbackInterface, SettingsInterface} import com.nutomic.ensichat.core.{Address, Message, User} import com.typesafe.scalalogging.Logger +import org.joda.time import slick.driver.H2Driver.api._ import scala.concurrent.Await @@ -19,10 +21,15 @@ import scala.concurrent.duration.Duration * * @param path The database file. */ -class Database(path: File, callbackInterface: CallbackInterface) { +class Database(path: File, settings: SettingsInterface, callbackInterface: CallbackInterface) { private val logger = Logger(this.getClass) + private val DatabaseVersionKey = "database_version" + private val DatabaseVersion = 2 + + private val DatabasePath = "jdbc:h2:" + path.getAbsolutePath + ";DATABASE_TO_UPPER=false" + private class Messages(tag: Tag) extends Table[Message](tag, "MESSAGES") { def id = primaryKey("id", (origin, messageId)) def origin = column[String]("origin") @@ -30,20 +37,23 @@ class Database(path: File, callbackInterface: CallbackInterface) { def messageId = column[Long]("message_id") def text = column[String]("text") def date = column[Long]("date") - def * = (origin, target, messageId, text, date).<> [Message, (String, String, Long, String, Long)]( { tuple => - val header = new ContentHeader(new Address(tuple._1), - new Address(tuple._2), - -1, - Text.Type, - Some(tuple._3), - Some(new Date(tuple._5))) - val body = new Text(tuple._4) - new Message(header, body) - }, { message => + def tokens = column[Int]("tokens") + def * = (origin, target, messageId, text, date, tokens) <> [Message, (String, String, Long, String, Long, Int)]( { + tuple => + val header = new ContentHeader(new Address(tuple._1), + new Address(tuple._2), + -1, + Text.Type, + Some(tuple._3), + Some(new Date(tuple._5)), + tuple._6) + val body = new Text(tuple._4) + new Message(header, body) + }, message => Option((message.header.origin.toString(), message.header.target.toString(), message.header.messageId.get, message.body.asInstanceOf[Text].text, - message.header.time.get.getTime)) - }) + message.header.time.get.getTime, message.header.tokens)) + ) } private val messages = TableQuery[Messages] @@ -51,12 +61,21 @@ class Database(path: File, callbackInterface: CallbackInterface) { def address = column[String]("address", O.PrimaryKey) def name = column[String]("name") def status = column[String]("status") - def wrappedAddress = address.<> [Address, String](new Address(_), a => Option(a.toString())) + def wrappedAddress = address <> [Address, String](new Address(_), a => Option(a.toString)) def * = (wrappedAddress, name, status) <> (User.tupled, User.unapply) } private val contacts = TableQuery[Contacts] - private val db = Database.forURL("jdbc:h2:" + path.getAbsolutePath, driver = "org.h2.Driver") + private class KnownDevices(tag: Tag) extends Table[(Address, time.Duration)](tag, "KNOWN_DEVICES") { + def address = column[String]("address", O.PrimaryKey) + def totalConnectionSeconds = column[Long]("total_connection_seconds") + def * = (address, totalConnectionSeconds) <> [(Address, time.Duration), (String, Long)]( + tuple => (new Address(tuple._1), time.Duration.standardSeconds(tuple._2)), + tuple => Option((tuple._1.toString, tuple._2.getStandardSeconds))) + } + private val knownDevices = TableQuery[KnownDevices] + + private val db = Database.forURL(DatabasePath, driver = "org.h2.Driver") // Create tables if database doesn't exist. { @@ -64,7 +83,25 @@ class Database(path: File, callbackInterface: CallbackInterface) { val dbFile = new File(path.getAbsolutePath + ".mv.db") if (!dbFile.exists()) { logger.info("Database does not exist, creating tables") - Await.result(db.run((messages.schema ++ contacts.schema).create), Duration.Inf) + val query = (messages.schema ++ contacts.schema ++ knownDevices.schema).create + Await.result(db.run(query), Duration.Inf) + settings.put(DatabaseVersionKey, DatabaseVersion) + } + } + + // Apparently, slick doesn't support ALTER TABLE, so we have to write raw SQL for this... + { + val oldVersion = settings.get(DatabaseVersionKey, 0) + if (oldVersion != DatabaseVersion) { + logger.info(s"Upgrading database from version $oldVersion to $DatabaseVersion") + val connection = DriverManager.getConnection(DatabasePath) + if (oldVersion <= 2) { + connection.createStatement().executeUpdate("ALTER TABLE MESSAGES ADD COLUMN (tokens INT);") + connection.commit() + Await.result(db.run(knownDevices.schema.create), Duration.Inf) + } + connection.close() + settings.put(DatabaseVersionKey, DatabaseVersion) } } @@ -122,4 +159,30 @@ class Database(path: File, callbackInterface: CallbackInterface) { callbackInterface.onContactsUpdated() } + def insertOrUpdateKnownDevice(address: Address, connectionTime: time.Duration): Unit = { + val query = knownDevices.insertOrUpdate((address, connectionTime)) + Await.result(db.run(query), Duration.Inf) + } + + /** + * Returns neighbors sorted by connection time, according to [[KnownDevices]]. + */ + def pickLongestConnectionDevice(connections: Set[Address]): List[Address] = { + val map = Await.result(db.run(knownDevices.result), Duration.Inf).toMap + connections + .toList + .sortBy(map(_).getMillis) + .reverse + } + + def updateMessageForwardingTokens(message: Message, tokens: Int): Unit = { + val query = messages.filter { c => + c.origin === message.header.origin.toString && + c.messageId === message.header.messageId + } + .map(_.tokens) + .update(tokens) + Await.result(db.run(query), Duration.Inf) + } + } diff --git a/core/src/main/scala/com/nutomic/ensichat/core/util/MessageBuffer.scala b/core/src/main/scala/com/nutomic/ensichat/core/util/MessageBuffer.scala index b82da3c..988a405 100644 --- a/core/src/main/scala/com/nutomic/ensichat/core/util/MessageBuffer.scala +++ b/core/src/main/scala/com/nutomic/ensichat/core/util/MessageBuffer.scala @@ -82,6 +82,8 @@ class MessageBuffer(retryMessageSending: (Address) => Unit) { ret.map(_.message) } + def getAllMessages: Set[Message] = values.map(_.message) + private def handleTimeouts(): Unit = { values = values.filter { e => e.retryCount < MaxRetryCount diff --git a/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala index 2af4609..0b624a7 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/MessageTest.scala @@ -40,7 +40,7 @@ class MessageTest extends TestCase { } def testSerializeSigned(): Unit = { - val header = new MessageHeader(ConnectionInfo.Type, AddressTest.a4, AddressTest.a2, 0) + val header = new MessageHeader(ConnectionInfo.Type, AddressTest.a4, AddressTest.a2, 0, 3) val m = new Message(header, ConnectionInfoTest.generateCi()) val signed = crypto.sign(m) diff --git a/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala index 88ca751..677626a 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/RouterTest.scala @@ -46,7 +46,7 @@ class RouterTest extends TestCase { assertEquals(msg.header.seqNum, m.header.seqNum) assertEquals(msg.header.protocolType, m.header.protocolType) assertEquals(msg.header.hopCount + 1, m.header.hopCount) - assertEquals(msg.header.hopLimit, m.header.hopLimit) + assertEquals(msg.header.tokens, m.header.tokens) assertEquals(msg.body, m.body) assertEquals(msg.crypto, m.crypto) }, _ => ()) @@ -93,14 +93,14 @@ class RouterTest extends TestCase { def testHopLimit(): Unit = Range(19, 22).foreach { i => val msg = new Message( - new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), i), new Text("")) + new ContentHeader(AddressTest.a1, AddressTest.a2, 1, 1, Some(1), Some(new Date()), 3, i), new Text("")) val router = new Router(new LocalRoutesInfo(neighbors), (a, m) => fail(), _ => ()) router.forwardMessage(msg) } private def generateMessage(sender: Address, receiver: Address, seqNum: Int): Message = { val header = new ContentHeader(sender, receiver, seqNum, UserInfo.Type, Some(5), - Some(new GregorianCalendar(2014, 6, 10).getTime)) + Some(new GregorianCalendar(2014, 6, 10).getTime), 3) new Message(header, new UserInfo("", "")) } diff --git a/core/src/test/scala/com/nutomic/ensichat/core/header/MessageHeaderTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/header/MessageHeaderTest.scala index cd444b9..e1c660b 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/header/MessageHeaderTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/header/MessageHeaderTest.scala @@ -11,9 +11,9 @@ object MessageHeaderTest { 0) val h2 = new MessageHeader(ContentHeader.ContentMessageType, Address.Null, Address.Broadcast, - ContentHeader.SeqNumRange.last, 0xff) + ContentHeader.SeqNumRange.last, 0xff, 3) - val h3 = new MessageHeader(ContentHeader.ContentMessageType, Address.Broadcast, Address.Null, 0) + val h3 = new MessageHeader(ContentHeader.ContentMessageType, Address.Broadcast, Address.Null, 0, 3) val headers = Set(h1, h2, h3) diff --git a/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala b/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala index bc6e2b5..b8659af 100644 --- a/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala +++ b/core/src/test/scala/com/nutomic/ensichat/core/util/RouteMessageInfoTest.scala @@ -13,7 +13,7 @@ class RouteMessageInfoTest extends TestCase { * Test case in which we have an entry with the same type, origin and target. */ def testSameMessage(): Unit = { - val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0) val msg = new Message(header, new RouteRequest(AddressTest.a3, 2, 3, 1)) val rmi = new RouteMessageInfo() assertFalse(rmi.isMessageRedundant(msg)) @@ -24,12 +24,12 @@ class RouteMessageInfoTest extends TestCase { * Forward a message with a seqnum that is older than the latest. */ def testSeqNumOlder(): Unit = { - val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0) val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 0, 0, 0)) val rmi = new RouteMessageInfo() assertFalse(rmi.isMessageRedundant(msg1)) - val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 3) + val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 3, 0) val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 2, 0, 0)) assertTrue(rmi.isMessageRedundant(msg2)) } @@ -38,12 +38,12 @@ class RouteMessageInfoTest extends TestCase { * Announce a route with a metric that is worse than the existing one. */ def testMetricWorse(): Unit = { - val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0) val msg1 = new Message(header1, new RouteRequest(AddressTest.a3, 1, 0, 2)) val rmi = new RouteMessageInfo() assertFalse(rmi.isMessageRedundant(msg1)) - val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2) + val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2, 0) val msg2 = new Message(header2, new RouteRequest(AddressTest.a3, 1, 0, 4)) assertTrue(rmi.isMessageRedundant(msg2)) } @@ -52,12 +52,12 @@ class RouteMessageInfoTest extends TestCase { * Announce route with a better metric. */ def testMetricBetter(): Unit = { - val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val header1 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0) val msg1 = new Message(header1, new RouteReply(0, 4)) val rmi = new RouteMessageInfo() assertFalse(rmi.isMessageRedundant(msg1)) - val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2) + val header2 = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 2, 0) val msg2 = new Message(header2, new RouteReply(0, 2)) assertFalse(rmi.isMessageRedundant(msg2)) } @@ -68,7 +68,7 @@ class RouteMessageInfoTest extends TestCase { def testTimeout(): Unit = { val rmi = new RouteMessageInfo() DateTimeUtils.setCurrentMillisFixed(DateTime.now.getMillis) - val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1) + val header = new MessageHeader(RouteRequest.Type, AddressTest.a1, AddressTest.a2, 1, 0) val msg = new Message(header, new RouteRequest(AddressTest.a3, 0, 0, 0)) assertFalse(rmi.isMessageRedundant(msg)) diff --git a/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala b/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala index 2d88e60..4a59ec9 100644 --- a/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala +++ b/integration/src/main/scala/com.nutomic.ensichat.integration/LocalNode.scala @@ -45,12 +45,12 @@ class LocalNode(val index: Int, configFolder: File) extends CallbackInterface { private val databaseFile = new File(configFolder, "database") private val keyFolder = new File(configFolder, "keys") - private val database = new Database(databaseFile, this) private val settings = new SettingsInterface { private var values = Map[String, Any]() override def get[T](key: String, default: T): T = values.get(key).map(_.asInstanceOf[T]).getOrElse(default) override def put[T](key: String, value: T): Unit = values += (key -> value.asInstanceOf[Any]) } + private val database = new Database(databaseFile, settings, this) val crypto = new Crypto(settings, keyFolder) val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 0, port) diff --git a/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala b/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala index c4b4476..3e7dee7 100644 --- a/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala +++ b/integration/src/main/scala/com.nutomic.ensichat.integration/Main.scala @@ -20,40 +20,141 @@ import scalax.file.Path */ object Main extends App { - val nodes = createMesh() - System.out.println("\n\nAll nodes connected!\n\n") + // NOTE: These tests are somewhat fragile, and might fail randomly. It helps to run only + // one of the following functions at a time. + testNeighborSending() + testMeshMessageSending() + testIndirectRelay() + testNeighborRelay() + testMessageDeliveryOnConnect() + testSendDelayed() + testRouteChange() - sendMessages(nodes) - System.out.println("\n\nMessages sent!\n\n") + private def testNeighborSending(): Unit = { + val node1 = Await.result(createNode(1), Duration.Inf) + val node2 = Await.result(createNode(2), Duration.Inf) - // Stop node 1, forcing route errors and messages to use the (longer) path via nodes 7 and 8. - nodes(1).connectionHandler.stop() - System.out.println("Node 1 stopped") - sendMessages(nodes) - System.out.println("\n\nMessages after route change sent!\n\n") + connectNodes(node1, node2) + sendMessage(node1, node2) - // Create new node 9, send message from node 0 to its address, before actually connecting it. - // The message is automatically delivered when node 9 connects as neighbor. - val node9 = Await.result(createNode(9), Duration.Inf) - val timer = new Timer() - timer.schedule(new TimerTask { - override def run(): Unit = { - connectNodes(nodes(0), node9) - } - }, Duration(10, TimeUnit.SECONDS).toMillis) - sendMessage(nodes(0), node9, 30) + Set(node1, node2).foreach(_.stop()) - // Create new node 10, send message from node 7 to its address, before connecting it to the mesh. - // The message is delivered after node 7 starts a route discovery triggered by the message buffer. - val node10 = Await.result(createNode(10), Duration.Inf) - timer.schedule(new TimerTask { - override def run(): Unit = { - connectNodes(nodes(0), node10) - timer.cancel() - } - }, Duration(5, TimeUnit.SECONDS).toMillis) - sendMessage(nodes(7), node10, 30) - System.out.println("\n\nMessages after delay sent!\n\n") + System.out.println("Test neighbor sending successful!") + } + + private def testNeighborRelay(): Unit = { + val nodes = createNodes(3) + + connectNodes(nodes(0), nodes(1)) + + val timer = new Timer() + timer.schedule(new TimerTask { + override def run(): Unit = { + nodes(0).stop() + connectNodes(nodes(1), nodes(2)) + } + }, Duration(10, TimeUnit.SECONDS).toMillis) + sendMessage(nodes(0), nodes(2), 30) + + timer.cancel() + nodes.foreach(_.stop()) + + System.out.println("Test neighbor relay successful!") + } + + private def testIndirectRelay(): Unit = { + val nodes = createNodes(5) + + + connectNodes(nodes(0), nodes(1)) + connectNodes(nodes(1), nodes(2)) + connectNodes(nodes(2), nodes(3)) + + val timer = new Timer() + timer.schedule(new TimerTask { + override def run(): Unit = { + nodes(0).stop() + connectNodes(nodes(3), nodes(4)) + } + }, Duration(10, TimeUnit.SECONDS).toMillis) + sendMessage(nodes(0), nodes(4), 30) + + timer.cancel() + nodes.foreach(_.stop()) + + System.out.println("Test indirect sending successful!") + } + + private def testMeshMessageSending(): Unit = { + val nodes = createMesh() + + sendMessages(nodes) + + nodes.foreach(_.stop()) + + System.out.println("Test mesh message sending successful!") + } + + /** + * Stop node 1, forcing route errors and messages to use the (longer) path via nodes 7 and 8. + */ + private def testRouteChange() { + val nodes = createMesh() + nodes(1).connectionHandler.stop() + sendMessages(nodes) + + nodes.foreach(_.stop()) + + System.out.println("Test route change successful!") + } + + /** + * Create new node 9, send message from node 0 to its address, before actually connecting it. + * The message is automatically delivered when node 9 connects as neighbor. + */ + private def testMessageDeliveryOnConnect() { + val nodes = createMesh() + val node9 = Await.result(createNode(9), Duration.Inf) + val timer = new Timer() + timer.schedule(new TimerTask { + override def run(): Unit = { + connectNodes(nodes(0), node9) + timer.cancel() + } + }, Duration(10, TimeUnit.SECONDS).toMillis) + sendMessage(nodes(0), node9, 30) + + (nodes :+ node9).foreach(_.stop()) + + System.out.println("Test message delivery on connect successful!") + } + + /** + * Create new node 10, send message from node 7 to its address, before connecting it to the mesh. + * The message is delivered after node 7 starts a route discovery triggered by the message buffer. + */ + private def testSendDelayed(): Unit = { + val nodes = createMesh() + val timer = new Timer() + val node10 = Await.result(createNode(10), Duration.Inf) + timer.schedule(new TimerTask { + override def run(): Unit = { + connectNodes(nodes(0), node10) + timer.cancel() + } + }, Duration(5, TimeUnit.SECONDS).toMillis) + sendMessage(nodes(7), node10, 30) + + (nodes :+ node10).foreach(_.stop()) + + System.out.println("Test send delayed successful!") + } + + private def createNodes(count: Int): Seq[LocalNode] = { + val nodes = Await.result(Future.sequence((0 until count).map(createNode)), Duration.Inf) + nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}")) + nodes + } /** * Creates a new mesh with a predefined layout. @@ -68,8 +169,7 @@ object Main extends App { * @return List of [[LocalNode]]s, ordered from 0 to 8. */ private def createMesh(): Seq[LocalNode] = { - val nodes = Await.result(Future.sequence(0.to(8).map(createNode)), Duration.Inf) - sys.addShutdownHook(nodes.foreach(_.stop())) + val nodes = createNodes(9) connectNodes(nodes(0), nodes(1)) connectNodes(nodes(0), nodes(2)) @@ -82,7 +182,6 @@ object Main extends App { connectNodes(nodes(3), nodes(7)) connectNodes(nodes(0), nodes(8)) connectNodes(nodes(7), nodes(8)) - nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}")) nodes } @@ -119,7 +218,6 @@ object Main extends App { sendMessage(nodes(3), nodes(5)) sendMessage(nodes(4), nodes(6)) sendMessage(nodes(2), nodes(3)) - sendMessage(nodes(0), nodes(3)) sendMessage(nodes(3), nodes(6)) sendMessage(nodes(3), nodes(2)) } diff --git a/server/src/main/scala/com/nutomic/ensichat/server/Main.scala b/server/src/main/scala/com/nutomic/ensichat/server/Main.scala index 530bd59..2af3d3a 100644 --- a/server/src/main/scala/com/nutomic/ensichat/server/Main.scala +++ b/server/src/main/scala/com/nutomic/ensichat/server/Main.scala @@ -22,7 +22,7 @@ object Main extends App with CallbackInterface { private lazy val settings = new Settings(ConfigFile) private lazy val crypto = new Crypto(settings, KeyFolder) - private lazy val database = new Database(DatabaseFile, this) + private lazy val database = new Database(DatabaseFile, settings, this) private lazy val connectionHandler = new ConnectionHandler(settings, database, this, crypto, 7) init()