ZIO STM: CountDownLatch in (effectively) two lines
Inspired by the post Creating a dead simple CountDownLatch with ZIO by Amitay Horwitz of Wix, here’s a CountDownLatch
in (effectively) two lines, built with ZIO STM.
The code
First the code; we’ll dissect it afterwards.
package dev.fsvehla

import zio._
import zio.stm._

final class CountDownLatch private (count: TRef[Int]) {
  // Atomically decrements the count.
  val countDown: UIO[Unit] =
    count.update(_ - 1).commit.unit

  // Succeeds once the count is <= 0; otherwise collect retries the
  // transaction, suspending the fiber until count changes.
  val await: UIO[Unit] =
    count.get.collect { case n if n <= 0 => () }.commit
}

object CountDownLatch {
  def make(count: Int): UIO[CountDownLatch] =
    TRef.make(count).map(r => new CountDownLatch(r)).commit
}
And here is a test suite, written with ZIO Test, yielding a proper code:test ratio.
package dev.fsvehla

import zio.test.Assertion._
import zio.test._
import zio.Schedule

//noinspection TypeAnnotation
object CountDownLatchSpec extends DefaultRunnableSpec {
  def spec =
    suite("CountDownLatch")(
      testM("notifies fibers that wait on completion") {
        for {
          latch        <- CountDownLatch.make(5)
          reader1Fiber <- latch.await.fork
          reader2Fiber <- latch.await.fork
          _            <- latch.countDown.repeat(Schedule.recurs(4))
          _            <- reader1Fiber.join
          _            <- reader2Fiber.join
        } yield {
          assertCompletes
        }
      },
      testM("notifies new fibers on completion") {
        for {
          latch <- CountDownLatch.make(5)
          _     <- latch.countDown.repeat(Schedule.recurs(4))
          _     <- latch.await
        } yield {
          assertCompletes
        }
      }
    )
}
Note well that latch.countDown.repeat(Schedule.recurs(4)) runs the effect once and then repeats it four more times, so countDown is invoked five times in total, exactly enough to bring a latch created with 5 down to zero.
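To see the repeat semantics in isolation, here is a minimal sketch, assuming ZIO 1.0.x (the object name RepeatDemo is made up for illustration). It counts, in a Ref, how often an effect repeated with Schedule.recurs(4) actually runs:

```scala
import zio._

// Sketch assuming ZIO 1.0.x; RepeatDemo is a hypothetical name.
object RepeatDemo {
  // Counts how many times the repeated effect runs.
  val program: URIO[ZEnv, Int] =
    for {
      runs <- Ref.make(0)
      // One initial run, then four repetitions scheduled by recurs(4).
      _    <- runs.update(_ + 1).repeat(Schedule.recurs(4))
      n    <- runs.get
    } yield n

  def run(): Int = Runtime.default.unsafeRun(program)
}
```

RepeatDemo.run() should yield 5: the initial run plus four repetitions.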
So, let’s dig into the code, starting with the creation of the CountDownLatch…
CountDownLatch.make
def make(count: Int): UIO[CountDownLatch] =
  TRef.make(count).map(r => new CountDownLatch(r)).commit
First of all, the signature tells us that creating a CountDownLatch is effectful and can’t fail. UIO[A] is a type alias for ZIO[Any, Nothing, A], which reads like this: has no requirements, can’t fail, and succeeds with an A.
We create a TRef, a transactional reference containing the passed-in count, and wrap it in a new instance of our CountDownLatch. TRef.make itself returns an STM transaction, so the final .commit runs it and turns it into the UIO[CountDownLatch].
countDown
val countDown: UIO[Unit] =
  count.update(_ - 1).commit.unit
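This code is nearly identical to what you would write with a zio.Ref: we decrement the Int contained in the TRef[Int]. The difference is that update returns an STM[Nothing, Int], a description of a transaction, and .commit runs that transaction, yielding a UIO[Int]; .unit then discards the result. (In later ZIO releases, TRef#update returns a transaction of Unit instead; thanks to the trailing .unit, countDown reads the same either way.)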
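To make the role of .commit concrete, here is a small sketch, assuming ZIO 1.0.x (CommitDemo is a hypothetical name). An STM value is only a description of a transaction, so committing the same description twice runs it twice:

```scala
import zio._
import zio.stm._

// Sketch assuming ZIO 1.0.x; CommitDemo is a hypothetical name.
object CommitDemo {
  val program: UIO[Int] =
    for {
      count     <- TRef.makeCommit(5)
      // Only a description of a transaction: nothing has happened yet.
      decrement  = count.update(_ - 1)
      _         <- decrement.commit // first run: 5 -> 4
      _         <- decrement.commit // second run: 4 -> 3
      left      <- count.get.commit
    } yield left

  def run(): Int = Runtime.default.unsafeRun(program)
}
```

CommitDemo.run() yields 3, since each .commit executes the transaction anew.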
Unlike with Ref, we could safely have written the code like this:
val countDown: UIO[Unit] =
  (for {
    x <- count.get
    _ <- count.set(x - 1)
  } yield ()).commit
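Changes to values that take part in a transaction, like TRefs, aren’t visible to other transactions until the transaction commits. If the count is updated by another transaction in the meantime, the whole transaction is retried. With a Ref, by contrast, a separate get and set are two independent steps, so two fibers could read the same value and one decrement would be lost.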
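A quick way to convince yourself is to run the get-then-set version from many fibers at once. This is a sketch assuming ZIO 1.0.x; GetSetDemo and countDownGetSet are hypothetical names, and the helper works on a bare TRef rather than on the latch:

```scala
import zio._
import zio.stm._

// Sketch assuming ZIO 1.0.x; GetSetDemo and countDownGetSet are hypothetical names.
object GetSetDemo {
  // The get-then-set variant of countDown, written against a bare TRef.
  def countDownGetSet(count: TRef[Int]): UIO[Unit] =
    (for {
      x <- count.get
      _ <- count.set(x - 1)
    } yield ()).commit

  val program: UIO[Int] =
    for {
      count <- TRef.makeCommit(1000)
      // 1000 fibers decrement concurrently; conflicting transactions are retried.
      _     <- ZIO.foreachPar_(1 to 1000)(_ => countDownGetSet(count))
      left  <- count.get.commit
    } yield left

  def run(): Int = Runtime.default.unsafeRun(program)
}
```

Even though the read and the write are separate steps inside the transaction, no decrement is lost, so GetSetDemo.run() yields 0.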
await
val await: UIO[Unit] =
  count.get.collect { case n if n <= 0 => () }.commit
This reads like this: get the count; if it is <= 0, return the unit value (); otherwise, retry.
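To watch the retry in action, the following sketch (assuming ZIO 1.0.x; AwaitDemo and awaitZero are hypothetical names) runs the same await transaction against a bare TRef and gives up after 50 milliseconds using timeout. While the count is 1, the transaction keeps retrying and the timeout wins; once the count drops to 0, it completes at once:

```scala
import zio._
import zio.duration._
import zio.stm._

// Sketch assuming ZIO 1.0.x; AwaitDemo and awaitZero are hypothetical names.
object AwaitDemo {
  // Same transaction as CountDownLatch.await, on a bare TRef.
  def awaitZero(count: TRef[Int]): UIO[Unit] =
    count.get.collect { case n if n <= 0 => () }.commit

  val program: URIO[ZEnv, (Option[Unit], Option[Unit])] =
    for {
      count  <- TRef.makeCommit(1)
      // count is 1: the transaction retries until the timeout interrupts it
      before <- awaitZero(count).timeout(50.millis)
      _      <- count.update(_ - 1).commit
      // count is 0: the transaction succeeds immediately
      after  <- awaitZero(count).timeout(50.millis)
    } yield (before, after)

  def run(): (Option[Unit], Option[Unit]) = Runtime.default.unsafeRun(program)
}
```

AwaitDemo.run() should return (None, Some(())).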
Here is a version that doesn’t use collect:
val await: UIO[Unit] =
  (for {
    x <- count.get
    _ <- if (x <= 0) STM.unit else STM.retry
  } yield ()).commit
STM.retry doesn’t use a spinlock. Instead, it semantically blocks the fiber waiting on .commit and retries the transaction once any of the values that take part in the transaction change.
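Because retry is part of the transaction rather than a busy loop, it also composes. As a sketch, assuming ZIO 1.0.x (FirstOfDemo, zero, and awaitEither are hypothetical names), orElse waits on whichever of two counters reaches zero first: if the left transaction retries, the right one is tried, and if both retry, the whole transaction retries and the fiber stays suspended until a TRef it read changes.

```scala
import zio._
import zio.stm._

// Sketch assuming ZIO 1.0.x; FirstOfDemo, zero and awaitEither are hypothetical names.
object FirstOfDemo {
  // Succeeds once the TRef is <= 0, otherwise retries.
  def zero(count: TRef[Int]): STM[Nothing, Unit] =
    count.get.collect { case n if n <= 0 => () }

  // Tries `a` first; if that retries, falls back to `b`.
  def awaitEither(a: TRef[Int], b: TRef[Int]): UIO[String] =
    (zero(a).map(_ => "a") orElse zero(b).map(_ => "b")).commit

  val program: UIO[String] =
    for {
      a      <- TRef.makeCommit(2)
      b      <- TRef.makeCommit(1)
      waiter <- awaitEither(a, b).fork
      _      <- b.update(_ - 1).commit // b reaches 0 and wakes the waiter
      winner <- waiter.join
    } yield winner

  def run(): String = Runtime.default.unsafeRun(program)
}
```

Since a never reaches zero here, FirstOfDemo.run() yields "b".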