Skip to content

Commit

Permalink
DEX-818 Invalid asset balance on connection (#312)
Browse files Browse the repository at this point in the history
A possible fix for invalid balances. The situation wasn't reproduced.

Also contains bug fixes:
* Connections with subscriptions to the address's changes can affect each other and the last subscription can be cancelled without an user's will;
* Negative balances can be propagated to the client;
* Changed assets should not be forgotten during new subscriptions. So their changes will be sent in the next WebSocket tick;
* The full state is set only once during multiple simultaneous connections. Fixes possible issues with the stale balances;

And improvements:
* AddressActor could have multiple WebSocket schedules;
  • Loading branch information
vsuharnikov authored Jul 7, 2020
1 parent e42f17c commit 9bf1d38
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import com.wavesplatform.dex.test.matchers.DiffMatcherWithImplicits
import mouse.any._
import org.scalatest.concurrent.Eventually
import org.scalatest.matchers.should.Matchers
import org.scalatest.{BeforeAndAfterAll, Suite}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite}

import scala.concurrent.duration._

trait HasWebSockets extends BeforeAndAfterAll with HasJwt with WsConnectionOps with WsMessageOps {
trait HasWebSockets extends BeforeAndAfterAll with BeforeAndAfterEach with HasJwt with WsConnectionOps with WsMessageOps {
_: Suite with Eventually with Matchers with DiffMatcherWithImplicits with PredefinedAssets =>

implicit protected val system: ActorSystem = ActorSystem()
Expand Down Expand Up @@ -80,15 +80,19 @@ trait HasWebSockets extends BeforeAndAfterAll with HasJwt with WsConnectionOps w
c.clearMessages()
}

protected def cleanupWebSockets(): Unit = {
if (!knownWsConnections.isEmpty) {
knownWsConnections.forEach { _.close() }
materializer.shutdown()
}
protected def cleanupWebSockets(): Unit = if (!knownWsConnections.isEmpty) {
knownWsConnections.forEach { _.close() }
knownWsConnections.clear()
}

override protected def afterEach(): Unit = {
super.afterEach()
cleanupWebSockets()
}

override def afterAll(): Unit = {
super.afterAll()
cleanupWebSockets()
materializer.shutdown()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import cats.instances.try_._
import com.dimafeng.testcontainers.GenericContainer
import com.typesafe.config.Config
import com.wavesplatform.dex.domain.utils.ScorexLogging
import com.wavesplatform.dex.it.api.HasWaitReady
import com.wavesplatform.dex.it.api.node.NodeApi
import com.wavesplatform.dex.it.cache.CachedData
import com.wavesplatform.dex.it.fp
Expand Down Expand Up @@ -39,8 +38,8 @@ final case class WavesNodeContainer(override val internalIp: String, underlying:

def grpcApiTarget: String = s"${grpcApiAddress.getHostName}:${grpcApiAddress.getPort}"

override def api: NodeApi[Id] = fp.sync { NodeApi[Try](apiKey, cachedRestApiAddress.get()) }
override def asyncApi: HasWaitReady[Future] = NodeApi[Future](apiKey, cachedRestApiAddress.get())
override def api: NodeApi[Id] = fp.sync { NodeApi[Try](apiKey, cachedRestApiAddress.get()) }
override def asyncApi: NodeApi[Future] = NodeApi[Future](apiKey, cachedRestApiAddress.get())

override def invalidateCaches(): Unit = {
super.invalidateCaches()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import com.wavesplatform.dex.domain.asset.Asset.Waves
import com.wavesplatform.dex.domain.model.Denormalization
import com.wavesplatform.dex.domain.order.OrderType.{BUY, SELL}
import com.wavesplatform.dex.error.SubscriptionsLimitReached
import com.wavesplatform.dex.it.api.responses.dex.OrderStatus
import com.wavesplatform.dex.it.waves.MkWavesEntities.IssueResults
import com.wavesplatform.dex.model.{LimitOrder, MarketOrder, OrderStatus}
import com.wavesplatform.dex.model.{LimitOrder, MarketOrder}
import com.wavesplatform.it.WsSuiteBase
import org.scalatest.prop.TableDrivenPropertyChecks

import scala.concurrent.duration._
import scala.concurrent.{Await, Future}

class WsAddressStreamTestSuite extends WsSuiteBase with TableDrivenPropertyChecks {

Expand Down Expand Up @@ -418,7 +420,7 @@ class WsAddressStreamTestSuite extends WsSuiteBase with TableDrivenPropertyCheck
}
}

"DEX-817 Invalid balances after connection (leasing)" in {
"DEX-817 Invalid WAVES balance after connection (leasing)" in {
val bobWavesBalanceBefore = dex1.api.tradableBalance(bob, wavesBtcPair)(Waves)

dex1.stopWithoutRemove()
Expand All @@ -439,5 +441,57 @@ class WsAddressStreamTestSuite extends WsSuiteBase with TableDrivenPropertyCheck

broadcastAndAwait(mkLeaseCancel(bob, leaseTx.getId))
}

"DEX-818" - {
"Connections can affect each other" in {
val wscs = (1 to 10).map(_ => mkWsAddressConnection(bob))
val mainWsc = mkWsAddressConnection(bob)

markup("Multiple orders")
val now = System.currentTimeMillis()
val orders = (1 to 50).map { i =>
mkOrderDP(bob, wavesBtcPair, BUY, 1.waves, 0.00012, ts = now + i)
}

Await.result(Future.traverse(orders)(dex1.asyncApi.place), 1.minute)
dex1.api.cancelAll(bob)

wscs.par.foreach(_.close())
Thread.sleep(3000)
mainWsc.clearMessages()

markup("A new order")
placeAndAwaitAtDex(mkOrderDP(bob, wavesBtcPair, BUY, 2.waves, 0.00029))

eventually {
mainWsc.receiveAtLeastN[WsAddressState](1)
}
mainWsc.clearMessages()
}

"Negative balances" in {
val carol = mkAccountWithBalance(5.waves -> Waves)
val wsc = mkWsAddressConnection(carol)

val now = System.currentTimeMillis()
val txs = (1 to 2).map { i =>
mkTransfer(carol, alice, 5.waves - minFee, Waves, minFee, timestamp = now + i)
}
val simulation = Future.traverse(txs)(wavesNode1.asyncApi.broadcast(_))
Await.result(simulation, 1.minute)
wavesNode1.api.waitForHeightArise()

wsc.balanceChanges.zipWithIndex.foreach {
case (changes, i) =>
changes.foreach {
case (asset, balance) =>
withClue(s"$i: $asset -> $balance: ") {
balance.tradable should be >= 0.0
balance.reserved should be >= 0.0
}
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import com.wavesplatform.dex.settings.{DenormalizedMatchingRule, OrderRestrictio
import com.wavesplatform.it.WsSuiteBase

import scala.collection.immutable.TreeMap
import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, Future}

class WsOrderBookStreamTestSuite extends WsSuiteBase {
Expand Down Expand Up @@ -455,15 +456,14 @@ class WsOrderBookStreamTestSuite extends WsSuiteBase {

"Bugs" - {
"DEX-814 Connections can affect each other" in {
val wscs = (1 to 10).map(_ => mkWsOrderBookConnection(wavesBtcPair, dex1))
val wscs = (1 to 10).map(_ => mkWsOrderBookConnection(wavesBtcPair, dex1))
val mainWsc = mkWsOrderBookConnection(wavesBtcPair, dex1)

markup("Multiple orders")
val orders = (1 to 50).map { i =>
mkOrderDP(carol, wavesBtcPair, BUY, 1.waves + i, 0.00012)
}

import scala.concurrent.duration.DurationInt
Await.result(Future.traverse(orders)(dex1.asyncApi.place), 1.minute)
dex1.api.cancelAll(carol)

Expand All @@ -481,10 +481,10 @@ class WsOrderBookStreamTestSuite extends WsSuiteBase {
assetPair = wavesBtcPair,
asks = TreeMap.empty,
bids = TreeMap(0.00029d -> 2d),
lastTrade = None,
lastTrade = none,
updateId = 0,
timestamp = buffer.last.timestamp,
settings = None
settings = none
)
)
}
Expand Down
44 changes: 29 additions & 15 deletions dex/src/main/scala/com/wavesplatform/dex/AddressActor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class AddressActor(owner: Address,
// Saves from cases when a client does multiple requests with the same order
private val failedPlacements = MutableSet.empty[Order.Id]

private var addressWsMutableState = AddressWsMutableState.empty(owner)
private var addressWsMutableState = AddressWsMutableState.empty(owner)
private var wsSendSchedule: Cancellable = Cancellable.alreadyCancelled

override def receive: Receive = {
case command: Command.PlaceOrder =>
Expand Down Expand Up @@ -139,7 +140,7 @@ class AddressActor(owner: Address,

case msg: Message.BalanceChanged =>
if (addressWsMutableState.hasActiveSubscriptions) {
addressWsMutableState = addressWsMutableState.putSpendableAssets(msg.allChanges.keySet)
addressWsMutableState = addressWsMutableState.putSpendableAssets(msg.changedAssets)
}

val toCancel = getOrdersToCancel(msg.changesForAudit).filterNot(ao => isCancelling(ao.order.id()))
Expand All @@ -156,7 +157,7 @@ class AddressActor(owner: Address,
}

case Query.GetReservedBalance => sender ! Reply.Balance(openVolume.filter(_._2 > 0))
case Query.GetTradableBalance(forAssets) => getTradableBalance(forAssets).map(xs => Reply.Balance(xs.filter(_._2 > 0))).pipeTo(sender)
case Query.GetTradableBalance(forAssets) => getTradableBalance(forAssets).map(Reply.Balance).pipeTo(sender)

case Query.GetOrderStatus(orderId) => sender ! activeOrders.get(orderId).fold[OrderStatus](orderDB.status(orderId))(_.status)

Expand Down Expand Up @@ -287,43 +288,52 @@ class AddressActor(owner: Address,
addressWsMutableState = addressWsMutableState.removeSubscription(client)
context.unwatch(client)

// Received a snapshot for pending connections
case SpendableBalancesActor.Reply.GetSnapshot(allAssetsSpendableBalance) =>
allAssetsSpendableBalance match {
case Right(spendableBalance) =>
addressWsMutableState.sendSnapshot(
balances = mkWsBalances(spendableBalance),
orders = activeOrders.values.map(WsOrder.fromDomain(_))(collection.breakOut),
)
if (!addressWsMutableState.hasActiveSubscriptions) scheduleNextDiffSending
if (!addressWsMutableState.hasActiveSubscriptions) scheduleNextDiffSending()
addressWsMutableState = addressWsMutableState.flushPendingSubscriptions()
case Left(matcherError) =>
addressWsMutableState.pendingSubscription.foreach { _.unsafeUpcast[WsServerMessage] ! WsError.from(matcherError, time.correctedTime()) }
addressWsMutableState = addressWsMutableState.copy(pendingSubscription = Set.empty)
}

// It is time to send updates to clients. This block of code asks balances
case WsCommand.PrepareDiffForWsSubscribers =>
if (addressWsMutableState.hasActiveSubscriptions) {
if (addressWsMutableState.hasChanges) {
spendableBalancesActor ! SpendableBalancesActor.Query.GetState(owner, addressWsMutableState.getAllChangedAssets)
} else scheduleNextDiffSending
// We asked balances for current changedAssets and clean it here,
// because there are could be new changes between sent Query.GetState and received Reply.GetState.
addressWsMutableState = addressWsMutableState.cleanBalanceChanges()
} else scheduleNextDiffSending()
}

// It is time to send updates to clients. This block of code sends balances
case SpendableBalancesActor.Reply.GetState(spendableBalances) =>
if (addressWsMutableState.hasActiveSubscriptions) {
addressWsMutableState = addressWsMutableState.sendDiffs(
addressWsMutableState = if (addressWsMutableState.hasActiveSubscriptions) {
scheduleNextDiffSending()
addressWsMutableState.sendDiffs(
balances = mkWsBalances(spendableBalances),
orders = addressWsMutableState.getAllOrderChanges
)
scheduleNextDiffSending
}
} else if (addressWsMutableState.pendingSubscription.isEmpty) {
addressWsMutableState.cleanBalanceChanges() // There are neither active, nor pending connections
} else addressWsMutableState

addressWsMutableState = addressWsMutableState.cleanChanges()
addressWsMutableState = addressWsMutableState.cleanOrderChanges()

case classic.Terminated(wsSource) => addressWsMutableState = addressWsMutableState.removeSubscription(wsSource)
}

private def scheduleNextDiffSending: Cancellable = {
context.system.scheduler.scheduleOnce(settings.wsMessagesInterval, self, WsCommand.PrepareDiffForWsSubscribers)
private def scheduleNextDiffSending(): Unit = {
wsSendSchedule.cancel()
wsSendSchedule = context.system.scheduler.scheduleOnce(settings.wsMessagesInterval, self, WsCommand.PrepareDiffForWsSubscribers)
}

private def denormalizedBalanceValue(asset: Asset, decimals: Int)(balanceSource: Map[Asset, Long]): Double =
Expand Down Expand Up @@ -384,7 +394,7 @@ class AddressActor(owner: Address,
spendableBalancesActor
.ask(SpendableBalancesActor.Query.GetState(owner, forAssets))(5.seconds, self) // TODO replace ask pattern by better solution
.mapTo[SpendableBalancesActor.Reply.GetState]
.map(xs => (xs.state |-| openVolume.filterKeys(forAssets)).withDefaultValue(0L))
.map(xs => (xs.state |-| openVolume.filterKeys(forAssets)).filter(_._2 > 0).withDefaultValue(0L))
}

private def scheduleExpiration(order: Order): Unit = if (enableSchedules && !expiration.contains(order.id())) {
Expand Down Expand Up @@ -537,6 +547,11 @@ class AddressActor(owner: Address,
ao <- activeOrders.values
if ao.isLimit && maybePair.forall(_ == ao.order.assetPair)
} yield ao

override def preRestart(reason: Throwable, message: Option[Any]): Unit = {
log.error(s"Failed on $message", reason)
super.preRestart(reason, message)
}
}

object AddressActor {
Expand All @@ -561,8 +576,7 @@ object AddressActor {

sealed trait Message
object Message {
// values of map allChanges can be used in future for tracking balances in AddressActor
case class BalanceChanged(allChanges: Map[Asset, Long], changesForAudit: Map[Asset, Long]) extends Message
case class BalanceChanged(changedAssets: Set[Asset], changesForAudit: Map[Asset, Long]) extends Message
}

sealed trait Query extends Message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ case class AddressWsMutableState(address: Address,
copy(activeSubscription = activeSubscription ++ pendingSubscription.iterator.map(_ -> 0L), pendingSubscription = Set.empty)

def removeSubscription(subscriber: ActorRef[WsAddressState]): AddressWsMutableState = {
if (activeSubscription.size == 1) copy(activeSubscription = Map.empty).cleanChanges()
else copy(activeSubscription = activeSubscription - subscriber)
val updated = copy(activeSubscription = activeSubscription - subscriber)
if (updated.activeSubscription.isEmpty) updated.cleanAllChanges()
else updated
}

def putReservedAssets(diff: Set[Asset]): AddressWsMutableState = copy(changedReservableAssets = changedReservableAssets ++ diff)
Expand Down Expand Up @@ -78,7 +79,10 @@ case class AddressWsMutableState(address: Address,
}
)

def cleanChanges(): AddressWsMutableState = copy(changedSpendableAssets = Set.empty, changedReservableAssets = Set.empty, ordersChanges = Map.empty)
def cleanAllChanges(): AddressWsMutableState =
copy(changedSpendableAssets = Set.empty, changedReservableAssets = Set.empty, ordersChanges = Map.empty)
def cleanOrderChanges(): AddressWsMutableState = copy(ordersChanges = Map.empty)
def cleanBalanceChanges(): AddressWsMutableState = copy(changedSpendableAssets = Set.empty, changedReservableAssets = Set.empty)
}

object AddressWsMutableState {
Expand Down
Loading

0 comments on commit 9bf1d38

Please sign in to comment.