Ferdinand Svehla

Developer.

© 2020

ZIO STM: CountDownLatch in (effectively) two lines

Inspired by the post Creating a dead simple CountDownLatch with ZIO by Wix’s Amitay Horwitz, here’s a CountDownLatch in (effectively) two lines using ZIO STM.

The code

First the code; we’ll dissect it afterwards.

package dev.fsvehla

import zio._
import zio.stm._

final class CountDownLatch private (count: TRef[Int]) {
  val countDown: UIO[Unit] =
    count.update(_ - 1).commit.unit

  val await: UIO[Unit] =
    count.get.collect { case n if n <= 0 => () }.commit
}

object CountDownLatch {
  def make(count: Int): UIO[CountDownLatch] =
    TRef.make(count).map(r => new CountDownLatch(r)).commit
}

And here is a test suite, written with ZIO Test, yielding a proper code:test ratio.

package dev.fsvehla

import zio.test.Assertion._
import zio.test._
import zio.Schedule

//noinspection TypeAnnotation
object CountDownLatchSpec extends DefaultRunnableSpec {
  def spec =
    suite("CountDownLatch")(
      testM("notifies fibers that wait on completion") {
        for {
          latch        <- CountDownLatch.make(5)
          reader1Fiber <- latch.await.fork
          reader2Fiber <- latch.await.fork
          _            <- latch.countDown.repeat(Schedule.recurs(4))
          _            <- reader1Fiber.join
          _            <- reader2Fiber.join
        } yield {
          assertCompletes
        }
      },
      testM("notifies new fibers on completion") {
        for {
          latch <- CountDownLatch.make(5)
          _     <- latch.countDown.repeat(Schedule.recurs(4))
          _     <- latch.await
        } yield {
          assertCompletes
        }
      }
    )
}

Note that latch.countDown.repeat(Schedule.recurs(4)) runs the effect once and then repeats it four more times, so countDown runs five times in total, matching the latch’s initial count of 5.
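To see the “one plus four” behaviour in isolation, here is a minimal sketch, assuming ZIO 1.0.x (the version that provides the testM and DefaultRunnableSpec used above) and running the effect with Runtime.default.unsafeRun. The RepeatDemo name is illustrative. It simply counts how often an effect runs under Schedule.recurs(4).

```scala
import zio._

object RepeatDemo {
  // Increments a counter once per run of the effect.
  // In ZIO 1.0.x, repeat requires a Clock, which ZEnv provides.
  val program: URIO[ZEnv, Int] =
    for {
      runs <- Ref.make(0)
      _    <- runs.update(_ + 1).repeat(Schedule.recurs(4))
      n    <- runs.get
    } yield n

  // One initial run plus four repetitions.
  val result: Int = Runtime.default.unsafeRun(program)
}
```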

So, let’s dig into the code, starting with the creation of the CountDownLatch.

CountDownLatch.make

def make(count: Int): UIO[CountDownLatch] =
  TRef.make(count).map(r => new CountDownLatch(r)).commit

First of all, the signature tells us that the creation of the CountDownLatch is effectful and can’t fail. UIO[A] is a type alias for ZIO[Any, Nothing, A], which reads like this: has no requirements, is infallible, and will return A.

We create a TRef, a transactional reference holding the passed-in count, wrap it in a new instance of our CountDownLatch, and .commit the transaction to turn the STM into a UIO.
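Because make returns a UIO rather than a latch, it is only a description: every time it runs, it allocates a fresh TRef. The following minimal sketch, assuming ZIO 1.0.x and using an illustrative FreshRefDemo object, runs the same TRef-creating effect twice and shows that the two references are independent.

```scala
import zio._
import zio.stm._

object FreshRefDemo {
  // A description of "create a TRef holding 1"; nothing is allocated yet.
  val makeCounter: UIO[TRef[Int]] = TRef.make(1).commit

  val program: UIO[(Int, Int)] =
    for {
      a  <- makeCounter // first run: one TRef
      b  <- makeCounter // second run: a different TRef
      _  <- a.update(_ - 1).commit
      va <- a.get.commit
      vb <- b.get.commit
    } yield (va, vb)

  // Only the first reference was decremented.
  val result: (Int, Int) = Runtime.default.unsafeRun(program)
}
```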

countDown

val countDown: UIO[Unit] =
  count.update(_ - 1).commit.unit

This code is nearly identical to what we’d write with a zio.Ref: we decrement the Int contained in the TRef[Int]. The difference is that count.update(_ - 1) is an STM transaction, so we use .commit to run it, which turns the STM into a UIO, and .unit to discard its result.
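To make the comparison with zio.Ref concrete, here is a minimal side-by-side sketch, assuming ZIO 1.0.x. The CountDownComparison object and the countDownT/countDownR helpers are hypothetical names. Both helpers decrement a counter; only the TRef version has a transaction to commit.

```scala
import zio._
import zio.stm._

object CountDownComparison {
  // STM version, as in the post: build a transaction, then commit it.
  def countDownT(count: TRef[Int]): UIO[Unit] =
    count.update(_ - 1).commit.unit

  // Ref version: the update is already an effect, nothing to commit.
  def countDownR(count: Ref[Int]): UIO[Unit] =
    count.update(_ - 1).unit

  val program: UIO[(Int, Int)] =
    for {
      t  <- TRef.make(5).commit
      r  <- Ref.make(5)
      _  <- ZIO.foreach_(1 to 5)(_ => countDownT(t) *> countDownR(r))
      vt <- t.get.commit
      vr <- r.get
    } yield (vt, vr)

  val result: (Int, Int) = Runtime.default.unsafeRun(program)
}
```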

In contrast to Ref, we could have safely written the code like this:

val countDown: UIO[Unit] =
  (for {
    x <- count.get
    _ <- count.set(x - 1)
  } yield ()).commit

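Changes to transactional values like TRefs aren’t visible to other transactions until the transaction that made them commits. If another fiber updates the count between our get and our set, the whole transaction is retried, so the read and the write act as a single atomic step. With a plain Ref, the same get followed by set could lose a concurrent update.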
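A minimal sketch of that guarantee, assuming ZIO 1.0.x (AtomicDemo is an illustrative name): 100 fibers run the get-then-set transaction from above concurrently against one TRef. Because each transaction is atomic, no decrement is lost and the count ends at 0. The same experiment with a plain Ref get followed by set would not be deterministic, which is why it isn’t shown with an expected result.

```scala
import zio._
import zio.stm._

object AtomicDemo {
  // The post's get-then-set countDown, as one STM transaction.
  def countDown(count: TRef[Int]): UIO[Unit] =
    (for {
      x <- count.get
      _ <- count.set(x - 1)
    } yield ()).commit

  // 100 fibers decrement the same TRef concurrently.
  val program: UIO[Int] =
    for {
      count <- TRef.make(100).commit
      _     <- ZIO.foreachPar_(1 to 100)(_ => countDown(count))
      n     <- count.get.commit
    } yield n

  val result: Int = Runtime.default.unsafeRun(program)
}
```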

await

val await: UIO[Unit] =
  count.get.collect { case n if n <= 0 => () }.commit

This reads as: get the count; if it is <= 0, return the unit value (), otherwise retry. Here is a version that doesn’t use collect:

val await: UIO[Unit] =
  (for {
    x <- count.get
    _ <- if (x <= 0) STM.unit else STM.retry
  } yield ()).commit

STM.retry doesn’t spin: it semantically blocks the fiber waiting on .commit and retries the transaction only when one of the values the transaction read (here, count) changes.
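A minimal sketch of that behaviour, assuming ZIO 1.0.x (WakeUpDemo is an illustrative name): a forked fiber runs the post’s collect-based await against a bare TRef, and three decrements then take the count from 2 to -1. Because the condition is <= 0 rather than == 0, the waiter is released even if it only gets to look at the count after it has overshot zero.

```scala
import zio._
import zio.stm._

object WakeUpDemo {
  // The post's await, applied to a bare TRef: retries until count <= 0.
  def await(count: TRef[Int]): UIO[Unit] =
    count.get.collect { case n if n <= 0 => () }.commit

  val program: UIO[Int] =
    for {
      count  <- TRef.make(2).commit
      waiter <- await(count).fork // semantically blocked while count > 0
      _      <- ZIO.foreach_(1 to 3)(_ => count.update(_ - 1).commit)
      _      <- waiter.join       // completes: count is now <= 0
      n      <- count.get.commit
    } yield n

  val result: Int = Runtime.default.unsafeRun(program)
}
```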