Skip to content

Commit

Permalink
Remove UTF8 decoder (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
mijicd authored Mar 31, 2023
1 parent a993b0b commit 266abfb
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 78 deletions.
4 changes: 3 additions & 1 deletion modules/redis/src/main/scala/zio/redis/Output.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import zio.redis.options.Cluster.{Node, Partition, SlotRange}
import zio.schema.Schema
import zio.schema.codec.BinaryCodec

import java.nio.charset.StandardCharsets

sealed trait Output[+A] { self =>
protected def tryDecode(respValue: RespValue): A

Expand Down Expand Up @@ -718,7 +720,7 @@ object Output {
}

private def decodeDouble(bytes: Chunk[Byte]): Double = {
val text = RespValue.decode(bytes)
val text = new String(bytes.toArray, StandardCharsets.UTF_8)
try text.toDouble
catch {
case _: NumberFormatException => throw ProtocolError(s"'$text' isn't a double.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ final class SingleNodeExecutor private (

while (it.hasNext) {
val req = it.next()
buffer ++= RespValue.Array(req.command).serialize
buffer ++= RespValue.Array(req.command).asBytes
}

val bytes = buffer.result()
Expand Down
100 changes: 58 additions & 42 deletions modules/redis/src/main/scala/zio/redis/internal/RespValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private[redis] sealed trait RespValue extends Product with Serializable { self =
import RespValue._
import RespValue.internal.{CrLf, Headers, NullArrayEncoded, NullStringEncoded}

final def serialize: Chunk[Byte] =
final def asBytes: Chunk[Byte] =
self match {
case NullBulkString => NullStringEncoded
case NullArray => NullArrayEncoded
Expand All @@ -39,7 +39,7 @@ private[redis] sealed trait RespValue extends Product with Serializable { self =
Headers.BulkString +: (encode(bytes.length.toString) ++ bytes ++ CrLf)

case Array(elements) =>
val data = elements.foldLeft[Chunk[Byte]](Chunk.empty)(_ ++ _.serialize)
val data = elements.foldLeft[Chunk[Byte]](Chunk.empty)(_ ++ _.asBytes)
Headers.Array +: (encode(elements.size.toString) ++ data)
}

Expand Down Expand Up @@ -72,9 +72,9 @@ private[redis] object RespValue {
final case class Integer(value: Long) extends RespValue

final case class BulkString(value: Chunk[Byte]) extends RespValue {
def asLong: Long = internal.unsafeReadLong(asString, 0)
def asLong: Long = internal.unsafeReadLong(value, 0)

def asString: String = decode(value)
def asString: String = internal.decode(value)
}

final case class Array(values: Chunk[RespValue]) extends RespValue
Expand All @@ -96,24 +96,20 @@ private[redis] object RespValue {

// ZSink fold will return a State.Start when contFn is false
val lineProcessor =
ZSink.fold[String, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
ZSink.foldChunks[Byte, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
case State.Done(value) => ZIO.succeed(Some(value))
case State.Failed => ZIO.fail(RedisError.ProtocolError("Invalid data received."))
case State.Start => ZIO.succeed(None)
case other => ZIO.dieMessage(s"Deserialization bug, should not get $other")
}

(ZPipeline.utf8Decode >>> ZPipeline.splitOn(internal.CrLfString))
.mapError(e => RedisError.ProtocolError(e.getLocalizedMessage))
.andThen(ZPipeline.fromSink(lineProcessor))
ZPipeline.splitOnChunk(internal.CrLf) >>> ZPipeline.fromSink(lineProcessor)
}

def array(values: RespValue*): Array = Array(Chunk.fromIterable(values))

def bulkString(s: String): BulkString = BulkString(Chunk.fromArray(s.getBytes(StandardCharsets.UTF_8)))

def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)

private object internal {
object Headers {
final val SimpleString: Byte = '+'
Expand All @@ -124,11 +120,10 @@ private[redis] object RespValue {
}

final val CrLf: Chunk[Byte] = Chunk('\r', '\n')
final val CrLfString: String = "\r\n"
final val NullArrayEncoded: Chunk[Byte] = Chunk.fromArray("*-1\r\n".getBytes(StandardCharsets.US_ASCII))
final val NullArrayPrefix: String = "*-1"
final val NullStringEncoded: Chunk[Byte] = Chunk.fromArray("$-1\r\n".getBytes(StandardCharsets.US_ASCII))
final val NullStringPrefix: String = "$-1"
final val NullArrayEncoded: Chunk[Byte] = Chunk('*', '-', '1', '\r', '\n')
final val NullArrayPrefix: Chunk[Byte] = Chunk('*', '-', '1')
final val NullStringEncoded: Chunk[Byte] = Chunk('$', '-', '1', '\r', '\n')
final val NullStringPrefix: Chunk[Byte] = Chunk('$', '-', '1')

sealed trait State { self =>
import State._
Expand All @@ -139,22 +134,22 @@ private[redis] object RespValue {
case _ => true
}

final def feed(line: String): State =
final def feed(bytes: Chunk[Byte]): State =
self match {
case Start if line.isEmpty() => Start
case Start if line == NullStringPrefix => Done(NullBulkString)
case Start if line == NullArrayPrefix => Done(NullArray)

case Start if line.nonEmpty =>
line.head match {
case Headers.SimpleString => Done(SimpleString(line.tail))
case Headers.Error => Done(Error(line.tail))
case Headers.Integer => Done(Integer(unsafeReadLong(line, 1)))
case Start if bytes.isEmpty => Start
case Start if bytes == NullStringPrefix => Done(NullBulkString)
case Start if bytes == NullArrayPrefix => Done(NullArray)

case Start if bytes.nonEmpty =>
bytes.head match {
case Headers.SimpleString => Done(SimpleString(decode(bytes.tail)))
case Headers.Error => Done(Error(decode(bytes.tail)))
case Headers.Integer => Done(Integer(unsafeReadLong(bytes, 1)))
case Headers.BulkString =>
val size = unsafeReadLong(line, 1).toInt
CollectingBulkString(size, new StringBuilder(size))
val size = unsafeReadSize(bytes)
CollectingBulkString(size, ChunkBuilder.make(size))
case Headers.Array =>
val size = unsafeReadLong(line, 1).toInt
val size = unsafeReadSize(bytes)

if (size > 0)
CollectingArray(size, ChunkBuilder.make(size), Start.feed)
Expand All @@ -165,50 +160,71 @@ private[redis] object RespValue {
}

case CollectingArray(rem, vals, next) =>
next(line) match {
next(bytes) match {
case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
case Done(v) => Done(Array((vals += v).result()))
case state => CollectingArray(rem, vals, state.feed)
}

case CollectingBulkString(rem, vals) =>
if (line.length >= rem) {
val stringValue = vals.append(line.substring(0, rem)).toString
Done(BulkString(Chunk.fromArray(stringValue.getBytes(StandardCharsets.UTF_8))))
if (bytes.length >= rem) {
vals ++= bytes.take(rem)
Done(BulkString(vals.result()))
} else {
CollectingBulkString(rem - line.length - 2, vals.append(line).append(CrLfString))
vals ++= bytes
vals ++= CrLf
CollectingBulkString(rem - bytes.length - 2, vals)
}

case _ => Failed
}
}

object State {
case object Start extends State
case object Failed extends State
final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: String => State) extends State
final case class CollectingBulkString(rem: Int, vals: StringBuilder) extends State
final case class Done(value: RespValue) extends State
case object Start extends State
case object Failed extends State

final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: Chunk[Byte] => State)
extends State

final case class CollectingBulkString(rem: Int, vals: ChunkBuilder[Byte]) extends State

final case class Done(value: RespValue) extends State
}

def unsafeReadLong(text: String, startFrom: Int): Long = {
/** Interprets the given bytes as UTF-8 text and returns the resulting string. */
def decode(bytes: Chunk[Byte]): String = {
  val raw = bytes.toArray
  new String(raw, StandardCharsets.UTF_8)
}

def unsafeReadLong(bytes: Chunk[Byte], startFrom: Int): Long = {
var pos = startFrom
var res = 0L
var neg = false

if (text.charAt(pos) == '-') {
if (bytes(pos) == '-') {
neg = true
pos += 1
}

val len = text.length
val len = bytes.length

while (pos < len) {
res = res * 10 + text.charAt(pos) - '0'
res = res * 10 + bytes(pos) - '0'
pos += 1
}

if (neg) -res else res
}

/**
 * Reads the decimal number that follows the first byte of `bytes`.
 *
 * The byte at index 0 is skipped (callers in this file pass a line whose
 * first byte is the bulk-string or array header). Every remaining byte is
 * treated as an ASCII digit; nothing is validated, hence the `unsafe` prefix.
 *
 * @param bytes a single line, header byte included
 * @return the accumulated value, or 0 when nothing follows the first byte
 */
def unsafeReadSize(bytes: Chunk[Byte]): Int = {
  val end = bytes.length

  // Accumulates one digit per step, starting right after the header byte.
  @scala.annotation.tailrec
  def loop(idx: Int, acc: Int): Int =
    if (idx >= end) acc
    else loop(idx + 1, acc * 10 + bytes(idx) - '0')

  loop(1, 0)
}
}
}
56 changes: 22 additions & 34 deletions modules/redis/src/test/scala/zio/redis/internal/RespValueSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,35 @@ package zio.redis.internal

import zio.Chunk
import zio.redis._
import zio.stream.ZStream
import zio.test.Assertion._
import zio.test._

import java.nio.charset.StandardCharsets

object RespValueSpec extends BaseSpec {
def spec: Spec[Any, RedisError.ProtocolError] =
suite("RespValue")(
suite("serialization")(
test("array") {
val expected = Chunk.fromArray("*3\r\n$3\r\nabc\r\n:123\r\n$-1\r\n".getBytes(StandardCharsets.UTF_8))
val v = RespValue.array(RespValue.bulkString("abc"), RespValue.Integer(123), RespValue.NullBulkString)
assert(v.serialize)(equalTo(expected))
}
),
suite("deserialization")(
test("array") {
val values = Chunk(
RespValue.SimpleString("OK"),
test("serializes and deserializes messages") {
val values = Chunk(
RespValue.SimpleString("OK"),
RespValue.bulkString("test1"),
RespValue.array(
RespValue.bulkString("test1"),
RespValue.array(
RespValue.bulkString("test1"),
RespValue.Integer(42L),
RespValue.NullBulkString,
RespValue.array(RespValue.SimpleString("a"), RespValue.Integer(0L)),
RespValue.bulkString("in array"),
RespValue.SimpleString("test2")
),
RespValue.NullBulkString
)
RespValue.Integer(42L),
RespValue.NullBulkString,
RespValue.array(RespValue.SimpleString("a"), RespValue.Integer(0L)),
RespValue.bulkString("in array"),
RespValue.SimpleString("test2")
),
RespValue.NullBulkString
)

zio.stream.ZStream
.fromChunk(values)
.mapConcat(_.serialize)
.via(RespValue.Decoder)
.collect { case Some(value) =>
value
}
.runCollect
.map(assert(_)(equalTo(values)))
}
)
ZStream
.fromChunk(values)
.mapConcat(_.asBytes)
.via(RespValue.Decoder)
.collectSome
.runCollect
.map(assert(_)(equalTo(values)))
}
)
}

0 comments on commit 266abfb

Please sign in to comment.