ZIO STM: CountDownLatch in (effectively) two lines
Inspired by the post Creating a dead simple CountDownLatch with ZIO by Wix’ Amitay Horwitz, here’s a CountDownLatch in (effectively) two lines using ZIO STM.
The code
First the code; we’ll dissect it afterwards:
```scala
package dev.fsvehla

import zio._
import zio.stm._

final class CountDownLatch private (count: TRef[Int]) {
  val countDown: UIO[Unit] =
    count.update(_ - 1).commit.unit

  val await: UIO[Unit] =
    count.get.collect { case n if n <= 0 => () }.commit
}

object CountDownLatch {
  def make(count: Int): UIO[CountDownLatch] =
    TRef.make(count).map(r => new CountDownLatch(r)).commit
}
```
And here is a test suite, written with ZIO Test, yielding a proper code:test ratio.
```scala
package dev.fsvehla

import zio.test.Assertion._
import zio.test._
import zio.Schedule

//noinspection TypeAnnotation
object CountDownLatchSpec extends DefaultRunnableSpec {
  def spec =
    suite("CountDownLatch")(
      testM("notifies fibers that wait on completion") {
        for {
          latch        <- CountDownLatch.make(5)
          reader1Fiber <- latch.await.fork
          reader2Fiber <- latch.await.fork
          _            <- latch.countDown.repeat(Schedule.recurs(4))
          _            <- reader1Fiber.join
          _            <- reader2Fiber.join
        } yield {
          assertCompletes
        }
      },
      testM("notifies new fibers on completion") {
        for {
          latch <- CountDownLatch.make(5)
          _     <- latch.countDown.repeat(Schedule.recurs(4))
          _     <- latch.await
        } yield {
          assertCompletes
        }
      }
    )
}
```
Note that effect.repeat(Schedule.recurs(4)) runs the effect once and then repeats it four more times. That makes five invocations of countDown in total, which brings the latch from 5 down to 0.
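The following sketch checks the repeat arithmetic by counting invocations in a Ref. It assumes ZIO 1.0.x (the version implied by DefaultRunnableSpec and testM), and RepeatCountDemo is a hypothetical name used only for illustration.

```scala
import zio._
import zio.clock.Clock

object RepeatCountDemo {
  // Counts how often an effect runs under repeat(Schedule.recurs(4)).
  val invocations: URIO[Clock, Int] =
    for {
      calls <- Ref.make(0)
      // One initial run, then Schedule.recurs(4) allows four more runs.
      _     <- calls.update(_ + 1).repeat(Schedule.recurs(4))
      n     <- calls.get
    } yield n

  def main(args: Array[String]): Unit =
    println(Runtime.default.unsafeRun(invocations)) // prints 5
}
```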
So, let’s dig into the code, starting with the creation of the CountDownLatch.
CountDownLatch.make
```scala
def make(count: Int): UIO[CountDownLatch] =
  TRef.make(count).map(r => new CountDownLatch(r)).commit
```
First of all, the signature tells us that the creation of the CountDownLatch is effectful and can’t fail.
UIO[A] is a type alias for ZIO[Any, Nothing, A], which reads like this: has no requirements, is infallible, and will return A.
We create a TRef, a transactional reference holding the passed-in count, and wrap it in a new instance of our CountDownLatch. Finally, we commit the resulting STM transaction to get a UIO[CountDownLatch].
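The fact that make returns a UIO matters in practice. The value CountDownLatch.make(1) is only a description, and each time it runs it allocates a fresh TRef, so the resulting latches are independent. The sketch below demonstrates this. It assumes ZIO 1.0.x and repeats the latch class so that it compiles on its own. MakeDemo and makeOne are hypothetical names. The 100 ms timeout is only there to observe that the second latch's await does not complete.

```scala
import zio._
import zio.clock.Clock
import zio.duration._
import zio.stm._

// Same latch as in the post, repeated so this sketch compiles on its own.
final class CountDownLatch private (count: TRef[Int]) {
  val countDown: UIO[Unit] = count.update(_ - 1).commit.unit
  val await: UIO[Unit]     = count.get.collect { case n if n <= 0 => () }.commit
}

object CountDownLatch {
  def make(count: Int): UIO[CountDownLatch] =
    TRef.make(count).map(r => new CountDownLatch(r)).commit
}

object MakeDemo {
  // Only a description: every run of it allocates a new TRef.
  val makeOne: UIO[CountDownLatch] = CountDownLatch.make(1)

  val program: URIO[Clock, (Option[Unit], Option[Unit])] =
    for {
      first  <- makeOne
      second <- makeOne
      _      <- first.countDown
      // `first` has reached 0, so awaiting it completes right away.
      a      <- first.await.timeout(100.millis)
      // `second` still holds 1, so its await keeps retrying until the timeout.
      b      <- second.await.timeout(100.millis)
    } yield (a, b)

  def main(args: Array[String]): Unit =
    println(Runtime.default.unsafeRun(program)) // prints (Some(()),None)
}
```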
countDown
```scala
val countDown: UIO[Unit] =
  count.update(_ - 1).commit.unit
```
This code is nearly identical to what we’d write with a zio.Ref: we decrement the Int contained in the TRef[Int]. The difference is that we use .commit to run the STM transaction returned by update, which yields a UIO, and .unit then discards its result.
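The sketch below puts the Ref and TRef versions side by side. It assumes ZIO 1.0.x, and RefVsTRefDemo and its helper methods are hypothetical names. Both counters start at 100 and are decremented by 100 concurrent fibers. The only visible difference between the two is the extra .commit needed to run the STM update.

```scala
import zio._
import zio.stm._

object RefVsTRefDemo {
  // Plain Ref: update is atomic for this single reference.
  def decrementRef(ref: Ref[Int]): UIO[Unit] = ref.update(_ - 1).unit

  // TRef: the same update, described as an STM transaction and run via commit.
  def decrementTRef(tref: TRef[Int]): UIO[Unit] = tref.update(_ - 1).commit.unit

  // Decrement each counter from 100 to 0 using 100 concurrent fibers.
  val program: UIO[(Int, Int)] =
    for {
      ref  <- Ref.make(100)
      tref <- TRef.make(100).commit
      _    <- ZIO.foreachPar_(1 to 100)(_ => decrementRef(ref))
      _    <- ZIO.foreachPar_(1 to 100)(_ => decrementTRef(tref))
      r    <- ref.get
      t    <- tref.get.commit
    } yield (r, t)

  def main(args: Array[String]): Unit =
    println(Runtime.default.unsafeRun(program)) // prints (0,0)
}
```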
In contrast to Ref, where a separate get followed by a set could interleave with other fibers and lose updates, we could have safely written the code like this:
```scala
val countDown: UIO[Unit] =
  (for {
    x <- count.get
    _ <- count.set(x - 1)
  } yield ()).commit
```
Changes to values that take part in a transaction, such as TRefs, aren’t visible to other transactions until the transaction commits. If another transaction changes the count in the meantime, the whole transaction is retried.
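One way to check this claim is to run the get-then-set version from many fibers at once. The sketch below assumes ZIO 1.0.x, and GetSetTransactionDemo is a hypothetical name. With a plain Ref, the same two steps as separate effects could lose updates whenever fibers interleave between the read and the write. Inside one STM transaction, a conflicting commit makes the transaction retry, so the counter should end at exactly 0.

```scala
import zio._
import zio.stm._

object GetSetTransactionDemo {
  // Read and write run inside one transaction. If another fiber commits a
  // change to `count` between them, this transaction is retried.
  def countDown(count: TRef[Int]): UIO[Unit] =
    (for {
      x <- count.get
      _ <- count.set(x - 1)
    } yield ()).commit

  // 1000 fibers decrement concurrently; no decrement should be lost.
  val program: UIO[Int] =
    for {
      count <- TRef.make(1000).commit
      _     <- ZIO.foreachPar_(1 to 1000)(_ => countDown(count))
      n     <- count.get.commit
    } yield n

  def main(args: Array[String]): Unit =
    println(Runtime.default.unsafeRun(program)) // prints 0
}
```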
await
```scala
val await: UIO[Unit] =
  count.get.collect { case n if n <= 0 => () }.commit
```
This reads as follows: get the count; if it is <= 0, return the unit value (), otherwise retry.
Here is a version that doesn’t use collect:
```scala
val await: UIO[Unit] =
  (for {
    x <- count.get
    _ <- if (x <= 0) STM.unit else STM.retry
  } yield ()).commit
```
STM.retry doesn’t use a spinlock. Instead, it semantically blocks the fiber waiting on .commit and retries the transaction once any of the values that take part in the transaction change.
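The sketch below shows the wake-up behaviour. It assumes ZIO 1.0.x, and RetryWakeUpDemo is a hypothetical name. A fiber running the retry-based await is still incomplete after 100 ms while the count is 1. It completes once another fiber commits count.set(0), a change to a TRef the transaction read. The program yields true if the waiter had not finished before that change. Note that this sketch cannot observe the absence of spinning directly; it only shows the suspend-then-resume behaviour.

```scala
import zio._
import zio.clock.Clock
import zio.duration._
import zio.stm._

object RetryWakeUpDemo {
  // The await transaction from the post, written with STM.retry.
  def await(count: TRef[Int]): UIO[Unit] =
    (for {
      x <- count.get
      _ <- if (x <= 0) STM.unit else STM.retry
    } yield ()).commit

  val program: URIO[Clock, Boolean] =
    for {
      count  <- TRef.make(1).commit
      waiter <- await(count).fork     // count is 1, so the transaction retries
      _      <- ZIO.sleep(100.millis)
      before <- waiter.poll           // still running: None
      _      <- count.set(0).commit   // changes a TRef the waiter read
      _      <- waiter.join           // the retried transaction now succeeds
    } yield before.isEmpty

  def main(args: Array[String]): Unit =
    println(Runtime.default.unsafeRun(program)) // prints true
}
```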