Allow adding other device by address, auto request public key

This commit is contained in:
Felix Ableitner 2016-09-12 18:51:53 +02:00
parent 127e4c9ff2
commit f3ec28fef8
9 changed files with 246 additions and 19 deletions

View file

@ -282,6 +282,32 @@ Address is the address that is no longer reachable.
SeqNum is the sequence number of the route that is no longer available SeqNum is the sequence number of the route that is no longer available
(if known). Otherwise, set TargSeqNum = -1. This field is signed. (if known). Otherwise, set TargSeqNum = -1. This field is signed.
### PublicKeyRequest (Protocol-Type = 5)
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Contains an address for which the sender wants the corresponding public
key.
### PublicKeyReply (Protocol-Type = 6)
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Key Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/ /
\ Key (variable length) \
/ /
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Contains a node's public key in binary form. Sent in reply to a
PublicKeyRequest.
Content Messages Content Messages
---------------- ----------------

View file

@ -127,7 +127,7 @@ class ConnectionsActivity extends EnsichatActivity with OnItemClickListener {
.setMessage(getString(R.string.dialog_add_contact, user.name)) .setMessage(getString(R.string.dialog_add_contact, user.name))
.setPositiveButton(android.R.string.yes, new OnClickListener { .setPositiveButton(android.R.string.yes, new OnClickListener {
override def onClick(dialog: DialogInterface, which: Int): Unit = { override def onClick(dialog: DialogInterface, which: Int): Unit = {
database.get.addContact(user) service.get.addContact(user)
Toast.makeText(ConnectionsActivity.this, R.string.toast_contact_added, Toast.LENGTH_SHORT) Toast.makeText(ConnectionsActivity.this, R.string.toast_contact_added, Toast.LENGTH_SHORT)
.show() .show()
} }

View file

@ -42,12 +42,17 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
private lazy val messageBuffer = new MessageBuffer(crypto.localAddress, requestRoute) private lazy val messageBuffer = new MessageBuffer(crypto.localAddress, requestRoute)
/**
* Messages which we couldn't verify yet because we don't have the sender's public key.
*/
private var unverifiedMessages = Set[Message]()
/** /**
* Holds all known users. * Holds all known users.
* *
* This is for user names that were received during runtime, and is not persistent. * This is for user names that were received during runtime, and is not persistent.
*/ */
private var knownUsers = Set[util.User]() private var knownUsers = Set[User]()
/** /**
* Generates keys and starts Bluetooth interface. * Generates keys and starts Bluetooth interface.
@ -86,16 +91,28 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
assert(body.contentType != -1) assert(body.contentType != -1)
FutureHelper { FutureHelper {
val messageId = settings.get("message_id", 0L) val messageId = settings.get("message_id", 0L)
val header = new ContentHeader(crypto.localAddress, target, seqNumGenerator.next(), val header = ContentHeader(crypto.localAddress, target, seqNumGenerator.next(),
body.contentType, Some(messageId), Some(DateTime.now), AbstractHeader.InitialForwardingTokens) body.contentType, Some(messageId), Some(DateTime.now), AbstractHeader.InitialForwardingTokens)
settings.put("message_id", messageId + 1) settings.put("message_id", messageId + 1)
val msg = new Message(header, body) val msg = new Message(header, body)
onNewMessage(msg)
if (crypto.havePublicKey(target)) {
val encrypted = crypto.encryptAndSign(msg) val encrypted = crypto.encryptAndSign(msg)
router.forwardMessage(encrypted) router.forwardMessage(encrypted)
forwardMessageToRelays(encrypted) forwardMessageToRelays(encrypted)
onNewMessage(msg)
} }
else {
logger.info(s"Public key missing for $target, buffering message and sending key request")
requestPublicKey(target)
}
}
}
private def requestPublicKey(address: Address): Unit = {
val header = MessageHeader(PublicKeyRequest.Type, crypto.localAddress, Address.Broadcast, seqNumGenerator.next(), 0)
val msg = new Message(header, PublicKeyRequest(address))
router.forwardMessage(crypto.sign(msg))
} }
private def requestRoute(target: Address): Unit = { private def requestRoute(target: Address): Unit = {
@ -205,6 +222,38 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
.foreach(routeError(_, None)) .foreach(routeError(_, None))
} }
} }
return
case pkr: PublicKeyRequest =>
if (crypto.havePublicKey(pkr.address)) {
val header = MessageHeader(PublicKeyReply.Type, crypto.localAddress, msg.header.origin, seqNumGenerator.next(), 0)
val msg2 = new Message(header, PublicKeyReply(crypto.getPublicKey(pkr.address)))
router.forwardMessage(crypto.sign(msg2), Option(previousHop))
}
else {
router.forwardMessage(msg)
}
return
case pkr: PublicKeyReply =>
if (msg.header.target != crypto.localAddress) {
router.forwardMessage(msg)
return
}
val address = crypto.calculateAddress(pkr.key)
if (crypto.havePublicKey(address))
return
logger.info(s"Received public key for $address, resending and decrypting messages")
crypto.addPublicKey(address, pkr.key)
database.getMessages(address)
.filter(_.header.target == address)
.foreach{ m =>
sendTo(address, m.body)
}
val current = unverifiedMessages
.filter(_.header.origin == address)
current.foreach(decryptMessage)
unverifiedMessages --= current
return
case _ => case _ =>
} }
@ -214,6 +263,17 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
return return
} }
if (!crypto.havePublicKey(msg.header.origin)) {
logger.info(s"Received message from ${msg.header.origin} but don't have public key, buffering")
unverifiedMessages += msg
requestPublicKey(msg.header.origin)
return
}
decryptMessage(msg)
}
private def decryptMessage(msg: Message): Unit = {
val plainMsg = val plainMsg =
try { try {
if (!crypto.verify(msg)) { if (!crypto.verify(msg)) {
@ -241,7 +301,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
if (plainMsg.body.contentType == Text.Type) { if (plainMsg.body.contentType == Text.Type) {
logger.trace(s"Sending confirmation for $plainMsg") logger.trace(s"Sending confirmation for $plainMsg")
sendTo(plainMsg.header.origin, new messages.body.MessageReceived(plainMsg.header.messageId.get)) sendTo(plainMsg.header.origin, new MessageReceived(plainMsg.header.messageId.get))
} }
onNewMessage(plainMsg) onNewMessage(plainMsg)
@ -280,13 +340,13 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
*/ */
private def onNewMessage(msg: Message): Unit = msg.body match { private def onNewMessage(msg: Message): Unit = msg.body match {
case ui: UserInfo => case ui: UserInfo =>
val contact = new util.User(msg.header.origin, ui.name, ui.status) val contact = User(msg.header.origin, ui.name, ui.status)
knownUsers += contact knownUsers += contact
if (database.getContact(msg.header.origin).nonEmpty) if (database.getContact(msg.header.origin).nonEmpty)
database.updateContact(contact) database.updateContact(contact)
callbacks.onConnectionsChanged() callbacks.onConnectionsChanged()
case mr: messages.body.MessageReceived => case mr: MessageReceived =>
database.setMessageConfirmed(mr.messageId) database.setMessageConfirmed(mr.messageId)
case _ => case _ =>
val origin = msg.header.origin val origin = msg.header.origin
@ -339,7 +399,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
else else
logger.info("Node " + sender + " connected") logger.info("Node " + sender + " connected")
sendTo(sender, new UserInfo(settings.get(SettingsInterface.KeyUserName, ""), sendTo(sender, UserInfo(settings.get(SettingsInterface.KeyUserName, ""),
settings.get(SettingsInterface.KeyUserStatus, ""))) settings.get(SettingsInterface.KeyUserStatus, "")))
callbacks.onConnectionsChanged() callbacks.onConnectionsChanged()
resendMissingRouteMessages() resendMissingRouteMessages()
@ -372,7 +432,7 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
def getUser(address: Address) = def getUser(address: Address) =
allKnownUsers() allKnownUsers()
.find(_.address == address) .find(_.address == address)
.getOrElse(new util.User(address, address.toString(), "")) .getOrElse(User(address, address.toString(), ""))
/** /**
* This method should be called when the local device's internet connection has changed in any way. * This method should be called when the local device's internet connection has changed in any way.
@ -382,4 +442,12 @@ final class ConnectionHandler(settings: SettingsInterface, database: Database,
.find(_.isInstanceOf[InternetInterface]) .find(_.isInstanceOf[InternetInterface])
.foreach(_.asInstanceOf[InternetInterface].connectionChanged()) .foreach(_.asInstanceOf[InternetInterface].connectionChanged())
} }
def addContact(user: User): Unit = {
database.addContact(user)
if (!crypto.havePublicKey(user.address)) {
requestPublicKey(user.address)
}
}
} }

View file

@ -50,11 +50,13 @@ object Message {
val body = val body =
header.protocolType match { header.protocolType match {
case messages.body.ConnectionInfo.Type => messages.body.ConnectionInfo.read(remaining) case ConnectionInfo.Type => ConnectionInfo.read(remaining)
case RouteRequest.Type => RouteRequest.read(remaining) case RouteRequest.Type => RouteRequest.read(remaining)
case RouteReply.Type => RouteReply.read(remaining) case RouteReply.Type => RouteReply.read(remaining)
case RouteError.Type => RouteError.read(remaining) case RouteError.Type => RouteError.read(remaining)
case _ => new EncryptedBody(remaining) case PublicKeyRequest.Type => PublicKeyRequest.read(remaining)
case PublicKeyReply.Type => PublicKeyReply.read(remaining)
case _ => EncryptedBody(remaining)
} }
new Message(header, crypto, body) new Message(header, crypto, body)

View file

@ -0,0 +1,45 @@
package com.nutomic.ensichat.core.messages.body
import java.nio.ByteBuffer
import java.security.spec.X509EncodedKeySpec
import java.security.{KeyFactory, PublicKey}
import com.nutomic.ensichat.core.util.BufferUtils
import com.nutomic.ensichat.core.util.Crypto
object PublicKeyReply {
val Type = 6
/**
* Constructs [[ConnectionInfo]] instance from byte array.
*/
def read(array: Array[Byte]): PublicKeyReply = {
val b = ByteBuffer.wrap(array)
val length = BufferUtils.getUnsignedInt(b).toInt
val encoded = new Array[Byte](length)
b.get(encoded, 0, length)
val factory = KeyFactory.getInstance(Crypto.PublicKeyAlgorithm)
val key = factory.generatePublic(new X509EncodedKeySpec(encoded))
new PublicKeyReply(key)
}
}
case class PublicKeyReply(key: PublicKey) extends MessageBody {
override def protocolType = PublicKeyRequest.Type
override def contentType = -1
override def write: Array[Byte] = {
val b = ByteBuffer.allocate(length)
BufferUtils.putUnsignedInt(b, key.getEncoded.length)
b.put(key.getEncoded)
b.array()
}
override def length = 4 + key.getEncoded.length
}

View file

@ -0,0 +1,50 @@
package com.nutomic.ensichat.core.messages.body
import java.nio.ByteBuffer
import com.nutomic.ensichat.core.routing.Address
import com.nutomic.ensichat.core.util.BufferUtils
object PublicKeyRequest {
val Type = 5
/**
* Constructs [[Text]] instance from byte array.
*/
def read(array: Array[Byte]): PublicKeyRequest = {
val b = ByteBuffer.wrap(array)
val length = BufferUtils.getUnsignedInt(b).toInt
val bytes = new Array[Byte](length)
b.get(bytes, 0, length)
new PublicKeyRequest(new Address(bytes))
}
}
case class PublicKeyRequest(address: Address) extends MessageBody {
require(address != Address.Broadcast, "")
require(address != Address.Null, "")
override def protocolType = PublicKeyRequest.Type
override def contentType = -1
override def write: Array[Byte] = {
val b = ByteBuffer.allocate(length)
val bytes = address.bytes
BufferUtils.putUnsignedInt(b, bytes.length)
b.put(bytes)
b.array()
}
override def equals(a: Any): Boolean = a match {
case o: PublicKeyRequest => address == o.address
case _ => false
}
override def length = 4 + address.bytes.length
}

View file

@ -67,6 +67,7 @@ private[core] class Router(routesInfo: LocalRoutesInfo, send: (Address, Message)
send(a, incHopCount(msg)) send(a, incHopCount(msg))
markMessageSeen((msg.header.origin, msg.header.seqNum)) markMessageSeen((msg.header.origin, msg.header.seqNum))
case None => case None =>
if (msg.header.isInstanceOf[ContentHeader])
noRouteFound(msg) noRouteFound(msg)
} }
} }

View file

@ -160,7 +160,7 @@ class Database(path: File, settings: SettingsInterface, callbackInterface: Callb
/** /**
* Inserts the user as a new contact. * Inserts the user as a new contact.
*/ */
def addContact(contact: User): Unit = { private[core] def addContact(contact: User): Unit = {
Await.result(db.run(contacts += contact), Duration.Inf) Await.result(db.run(contacts += contact), Duration.Inf)
callbackInterface.onContactsUpdated() callbackInterface.onContactsUpdated()
} }

View file

@ -22,6 +22,7 @@ import scalax.file.Path
*/ */
object Main extends App { object Main extends App {
/*
testNeighborSending() testNeighborSending()
testMeshMessageSending() testMeshMessageSending()
testIndirectRelay() testIndirectRelay()
@ -30,6 +31,8 @@ object Main extends App {
testSendDelayed() testSendDelayed()
testRouteChange() testRouteChange()
testMessageConfirmation() testMessageConfirmation()
*/
testKeyRequest()
private def testNeighborSending(): Unit = { private def testNeighborSending(): Unit = {
val node1 = Await.result(createNode(1), Duration.Inf) val node1 = Await.result(createNode(1), Duration.Inf)
@ -169,6 +172,38 @@ object Main extends App {
nodes.foreach(_.stop()) nodes.foreach(_.stop())
} }
private def testKeyRequest(): Unit = {
val nodes = createNodes(4)
connectNodes(nodes(0), nodes(1))
connectNodes(nodes(1), nodes(2))
connectNodes(nodes(2), nodes(3))
val origin = nodes(0)
val target = nodes(3)
System.out.println(s"sendMessage(${origin.index}, ${target.index})")
val text = s"${origin.index} to ${target.index}"
origin.connectionHandler.sendTo(target.crypto.localAddress, new Text(text))
val latch = new CountDownLatch(1)
Future {
val exists =
target.eventQueue.toStream.exists { event =>
if (event._1 != LocalNode.EventType.MessageReceived)
false
else {
event._2.get.body match {
case t: Text => t.text == text
case _ => false
}
}
}
assert(exists, s"message from ${origin.index} did not arrive at ${target.index}")
latch.countDown()
}
assert(latch.await(3, TimeUnit.SECONDS))
}
private def createNodes(count: Int): Seq[LocalNode] = { private def createNodes(count: Int): Seq[LocalNode] = {
val nodes = Await.result(Future.sequence((0 until count).map(createNode)), Duration.Inf) val nodes = Await.result(Future.sequence((0 until count).map(createNode)), Duration.Inf)
nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}")) nodes.foreach(n => System.out.println(s"Node ${n.index} has address ${n.crypto.localAddress}"))