Skip to content

Commit

Permalink
Fixed channel interruptions.
Browse files Browse the repository at this point in the history
  • Loading branch information
klemens-morgenstern committed Feb 3, 2025
1 parent 74f5546 commit 61cf72a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 22 deletions.
36 changes: 24 additions & 12 deletions include/boost/cobalt/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@ struct channel

void interrupt_await()
{
this->cancelled = true;
if (awaited_from)
awaited_from.release().resume();
if (!direct)
{
this->cancelled = true;
if (this->awaited_from)
this->awaited_from.release().resume();
}
}

struct cancel_impl;
Expand Down Expand Up @@ -127,9 +130,12 @@ struct channel

void interrupt_await()
{
this->cancelled = true;
if (awaited_from)
awaited_from.release().resume();
if (!direct)
{
this->cancelled = true;
if (this->awaited_from)
this->awaited_from.release().resume();
}
}

struct cancel_impl;
Expand Down Expand Up @@ -250,9 +256,12 @@ struct channel<void>

void interrupt_await()
{
this->cancelled = true;
if (awaited_from)
awaited_from.release().resume();
if (!direct)
{
this->cancelled = true;
if (this->awaited_from)
this->awaited_from.release().resume();
}
}

struct cancel_impl;
Expand Down Expand Up @@ -284,9 +293,12 @@ struct channel<void>

void interrupt_await()
{
this->cancelled = true;
if (awaited_from)
awaited_from.release().resume();
if (!direct)
{
cancelled = true;
if (this->awaited_from)
this->awaited_from.release().resume();
}
}

struct cancel_impl;
Expand Down
29 changes: 19 additions & 10 deletions include/boost/cobalt/impl/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ std::coroutine_handle<void> channel<T>::read_op::await_suspend(std::coroutine_ha
{
cancel_slot.clear();
auto & op = chn->write_queue_.front();
op.transactional_unlink();
// transactional_unlink can interrupt or cancel `op` through `race`, so we need to check.
op.direct = true;
if constexpr (std::is_copy_constructible_v<T>)
{
Expand All @@ -127,7 +127,11 @@ std::coroutine_handle<void> channel<T>::read_op::await_suspend(std::coroutine_ha
}
else
direct = std::move(*op.ref);

op.transactional_unlink();
BOOST_ASSERT(op.awaited_from);
BOOST_ASSERT(awaited_from);

asio::post(chn->executor_, std::move(awaited_from));
return op.awaited_from.release();
}
Expand Down Expand Up @@ -207,7 +211,6 @@ std::coroutine_handle<void> channel<T>::write_op::await_suspend(std::coroutine_h
if constexpr (requires (Promise p) {p.begin_transaction();})
begin_transaction = +[](void * p){std::coroutine_handle<Promise>::from_address(p).promise().begin_transaction();};

// currently nothing to read
if (chn->read_queue_.empty())
{
chn->write_queue_.push_back(*this);
Expand All @@ -217,7 +220,6 @@ std::coroutine_handle<void> channel<T>::write_op::await_suspend(std::coroutine_h
{
cancel_slot.clear();
auto & op = chn->read_queue_.front();
op.transactional_unlink();
if constexpr (std::is_copy_constructible_v<T>)
{
if (ref.index() == 0)
Expand All @@ -228,8 +230,11 @@ std::coroutine_handle<void> channel<T>::write_op::await_suspend(std::coroutine_h
else
op.direct.emplace(std::move(*ref));

BOOST_ASSERT(op.awaited_from);
direct = true;
op.transactional_unlink();

BOOST_ASSERT(op.awaited_from);
BOOST_ASSERT(awaited_from);
asio::post(chn->executor_, std::move(awaited_from));

return op.awaited_from.release();
Expand All @@ -256,7 +261,6 @@ system::result<void> channel<T>::write_op::await_resume(const struct as_result_
if (cancelled)
boost::throw_exception(system::system_error(asio::error::operation_aborted), loc);


if (!direct)
{
BOOST_ASSERT(!chn->buffer_.full());
Expand Down Expand Up @@ -333,10 +337,12 @@ std::coroutine_handle<void> channel<void>::read_op::await_suspend(std::coroutine
{
cancel_slot.clear();
auto & op = chn->write_queue_.front();
op.unlink();
op.direct = true;
BOOST_ASSERT(op.awaited_from);
direct = true;
op.transactional_unlink();

BOOST_ASSERT(op.awaited_from);
BOOST_ASSERT(awaited_from);
asio::post(chn->executor_, std::move(awaited_from));
return op.awaited_from.release();
}
Expand Down Expand Up @@ -364,10 +370,13 @@ std::coroutine_handle<void> channel<void>::write_op::await_suspend(std::coroutin
{
cancel_slot.clear();
auto & op = chn->read_queue_.front();
op.unlink();
op.direct = true;
BOOST_ASSERT(op.awaited_from);
op.direct = true; // let interrupt_await know that we'll be resuming it!
direct = true;
op.transactional_unlink();

BOOST_ASSERT(op.awaited_from);
BOOST_ASSERT(awaited_from);

asio::post(chn->executor_, std::move(awaited_from));
return op.awaited_from.release();
}
Expand Down
46 changes: 46 additions & 0 deletions test/channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,5 +316,51 @@ CO_TEST_CASE(any)
}


CO_TEST_CASE(interrupt_)
{
cobalt::channel<int> c;
auto lr = co_await cobalt::left_race(c.write(42), c.read());
BOOST_CHECK(lr.index() == 0);
auto rl = co_await cobalt::left_race(c.read(), c.write(42));
BOOST_CHECK(rl.index() == 0);
}

CO_TEST_CASE(interrupt_void)
{
cobalt::channel<void> c;
auto lr = co_await cobalt::left_race(c.write(), c.read());
BOOST_CHECK(lr == 0);
auto rl = co_await cobalt::left_race(c.read(), c.write());
BOOST_CHECK(rl == 0);
}

CO_TEST_CASE(data_loss)
{
cobalt::channel<int> c1 {10};
cobalt::channel<int> c2 {10};
cobalt::channel<int> c3 {10};
for (int i = 0; i < 10; i++)
{
co_await c1.write(i);
co_await c2.write(1000 + i);
}
int i1 = 0;
int i2 = 1000;
std::default_random_engine g(0xDEADBBEF);

while (i1 < 10)
{
auto res = co_await cobalt::race(g, c1.read(), c2.read(), c3.read());
switch (res.index())
{
case 0:
BOOST_REQUIRE_EQUAL(boost::variant2::get<0>(res), i1++);
break;
case 1:
BOOST_REQUIRE_EQUAL(boost::variant2::get<1>(res), i2++);
break;
}
}
}

}

0 comments on commit 61cf72a

Please sign in to comment.