Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove UTF8 decoder #796

Merged
merged 8 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion modules/redis/src/main/scala/zio/redis/Output.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import zio.redis.options.Cluster.{Node, Partition, SlotRange}
import zio.schema.Schema
import zio.schema.codec.BinaryCodec

import java.nio.charset.StandardCharsets

sealed trait Output[+A] { self =>
protected def tryDecode(respValue: RespValue): A

Expand Down Expand Up @@ -718,7 +720,7 @@ object Output {
}

private def decodeDouble(bytes: Chunk[Byte]): Double = {
val text = RespValue.decode(bytes)
val text = new String(bytes.toArray, StandardCharsets.UTF_8)
try text.toDouble
catch {
case _: NumberFormatException => throw ProtocolError(s"'$text' isn't a double.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ final class SingleNodeExecutor private (

while (it.hasNext) {
val req = it.next()
buffer ++= RespValue.Array(req.command).serialize
buffer ++= RespValue.Array(req.command).asBytes
}

val bytes = buffer.result()
Expand Down
100 changes: 58 additions & 42 deletions modules/redis/src/main/scala/zio/redis/internal/RespValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private[redis] sealed trait RespValue extends Product with Serializable { self =
import RespValue._
import RespValue.internal.{CrLf, Headers, NullArrayEncoded, NullStringEncoded}

final def serialize: Chunk[Byte] =
final def asBytes: Chunk[Byte] =
self match {
case NullBulkString => NullStringEncoded
case NullArray => NullArrayEncoded
Expand All @@ -39,7 +39,7 @@ private[redis] sealed trait RespValue extends Product with Serializable { self =
Headers.BulkString +: (encode(bytes.length.toString) ++ bytes ++ CrLf)

case Array(elements) =>
val data = elements.foldLeft[Chunk[Byte]](Chunk.empty)(_ ++ _.serialize)
val data = elements.foldLeft[Chunk[Byte]](Chunk.empty)(_ ++ _.asBytes)
Headers.Array +: (encode(elements.size.toString) ++ data)
}

Expand Down Expand Up @@ -72,9 +72,9 @@ private[redis] object RespValue {
final case class Integer(value: Long) extends RespValue

final case class BulkString(value: Chunk[Byte]) extends RespValue {
def asLong: Long = internal.unsafeReadLong(asString, 0)
def asLong: Long = internal.unsafeReadLong(value, 0)

def asString: String = decode(value)
def asString: String = internal.decode(value)
}

final case class Array(values: Chunk[RespValue]) extends RespValue
Expand All @@ -96,24 +96,20 @@ private[redis] object RespValue {

// ZSink fold will return a State.Start when contFn is false
val lineProcessor =
ZSink.fold[String, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
ZSink.foldChunks[Byte, State](State.Start)(_.inProgress)(_ feed _).mapZIO {
case State.Done(value) => ZIO.succeed(Some(value))
case State.Failed => ZIO.fail(RedisError.ProtocolError("Invalid data received."))
case State.Start => ZIO.succeed(None)
case other => ZIO.dieMessage(s"Deserialization bug, should not get $other")
}

(ZPipeline.utf8Decode >>> ZPipeline.splitOn(internal.CrLfString))
.mapError(e => RedisError.ProtocolError(e.getLocalizedMessage))
.andThen(ZPipeline.fromSink(lineProcessor))
ZPipeline.splitOnChunk(internal.CrLf) >>> ZPipeline.fromSink(lineProcessor)
}

/** Builds a RESP [[Array]] by collecting the given values into a `Chunk`. */
def array(values: RespValue*): Array = {
  val elements = Chunk.fromIterable(values)
  Array(elements)
}

/** Builds a [[BulkString]] whose payload is the UTF-8 encoding of `s`. */
def bulkString(s: String): BulkString = {
  val utf8 = s.getBytes(StandardCharsets.UTF_8)
  BulkString(Chunk.fromArray(utf8))
}

def decode(bytes: Chunk[Byte]): String = new String(bytes.toArray, StandardCharsets.UTF_8)

private object internal {
object Headers {
final val SimpleString: Byte = '+'
Expand All @@ -124,11 +120,10 @@ private[redis] object RespValue {
}

final val CrLf: Chunk[Byte] = Chunk('\r', '\n')
final val CrLfString: String = "\r\n"
final val NullArrayEncoded: Chunk[Byte] = Chunk.fromArray("*-1\r\n".getBytes(StandardCharsets.US_ASCII))
final val NullArrayPrefix: String = "*-1"
final val NullStringEncoded: Chunk[Byte] = Chunk.fromArray("$-1\r\n".getBytes(StandardCharsets.US_ASCII))
final val NullStringPrefix: String = "$-1"
final val NullArrayEncoded: Chunk[Byte] = Chunk('*', '-', '1', '\r', '\n')
final val NullArrayPrefix: Chunk[Byte] = Chunk('*', '-', '1')
final val NullStringEncoded: Chunk[Byte] = Chunk('$', '-', '1', '\r', '\n')
final val NullStringPrefix: Chunk[Byte] = Chunk('$', '-', '1')

sealed trait State { self =>
import State._
Expand All @@ -139,22 +134,22 @@ private[redis] object RespValue {
case _ => true
}

final def feed(line: String): State =
final def feed(bytes: Chunk[Byte]): State =
self match {
case Start if line.isEmpty() => Start
case Start if line == NullStringPrefix => Done(NullBulkString)
case Start if line == NullArrayPrefix => Done(NullArray)

case Start if line.nonEmpty =>
line.head match {
case Headers.SimpleString => Done(SimpleString(line.tail))
case Headers.Error => Done(Error(line.tail))
case Headers.Integer => Done(Integer(unsafeReadLong(line, 1)))
case Start if bytes.isEmpty => Start
case Start if bytes == NullStringPrefix => Done(NullBulkString)
case Start if bytes == NullArrayPrefix => Done(NullArray)

case Start if bytes.nonEmpty =>
bytes.head match {
case Headers.SimpleString => Done(SimpleString(decode(bytes.tail)))
case Headers.Error => Done(Error(decode(bytes.tail)))
case Headers.Integer => Done(Integer(unsafeReadLong(bytes, 1)))
case Headers.BulkString =>
val size = unsafeReadLong(line, 1).toInt
CollectingBulkString(size, new StringBuilder(size))
val size = unsafeReadSize(bytes)
CollectingBulkString(size, ChunkBuilder.make(size))
case Headers.Array =>
val size = unsafeReadLong(line, 1).toInt
val size = unsafeReadSize(bytes)

if (size > 0)
CollectingArray(size, ChunkBuilder.make(size), Start.feed)
Expand All @@ -165,50 +160,71 @@ private[redis] object RespValue {
}

case CollectingArray(rem, vals, next) =>
next(line) match {
next(bytes) match {
case Done(v) if rem > 1 => CollectingArray(rem - 1, vals += v, Start.feed)
case Done(v) => Done(Array((vals += v).result()))
case state => CollectingArray(rem, vals, state.feed)
}

case CollectingBulkString(rem, vals) =>
if (line.length >= rem) {
val stringValue = vals.append(line.substring(0, rem)).toString
Done(BulkString(Chunk.fromArray(stringValue.getBytes(StandardCharsets.UTF_8))))
if (bytes.length >= rem) {
vals ++= bytes.take(rem)
Done(BulkString(vals.result()))
} else {
CollectingBulkString(rem - line.length - 2, vals.append(line).append(CrLfString))
vals ++= bytes
vals ++= CrLf
CollectingBulkString(rem - bytes.length - 2, vals)
}

case _ => Failed
}
}

object State {
case object Start extends State
case object Failed extends State
final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: String => State) extends State
final case class CollectingBulkString(rem: Int, vals: StringBuilder) extends State
final case class Done(value: RespValue) extends State
case object Start extends State
case object Failed extends State

final case class CollectingArray(rem: Int, vals: ChunkBuilder[RespValue], next: Chunk[Byte] => State)
extends State

final case class CollectingBulkString(rem: Int, vals: ChunkBuilder[Byte]) extends State

final case class Done(value: RespValue) extends State
}

def unsafeReadLong(text: String, startFrom: Int): Long = {
/** Interprets `bytes` as UTF-8 and returns the resulting string. */
def decode(bytes: Chunk[Byte]): String = {
  val raw = bytes.toArray
  new String(raw, StandardCharsets.UTF_8)
}

def unsafeReadLong(bytes: Chunk[Byte], startFrom: Int): Long = {
var pos = startFrom
var res = 0L
var neg = false

if (text.charAt(pos) == '-') {
if (bytes(pos) == '-') {
neg = true
pos += 1
}

val len = text.length
val len = bytes.length

while (pos < len) {
res = res * 10 + text.charAt(pos) - '0'
res = res * 10 + bytes(pos) - '0'
pos += 1
}

if (neg) -res else res
}

/**
 * Reads a non-negative decimal number from `bytes`, skipping the byte at index 0
 * (the callers in this file pass a line that starts with a RESP type header).
 *
 * "Unsafe" because nothing is validated: every byte from index 1 onwards is
 * treated as an ASCII digit, there is no sign handling, and no overflow check.
 * Returns 0 when there are no bytes after index 0.
 */
def unsafeReadSize(bytes: Chunk[Byte]): Int = {
  val end = bytes.length

  // Tail-recursive accumulation: compiles down to a loop, no per-byte allocation.
  @scala.annotation.tailrec
  def loop(idx: Int, acc: Int): Int =
    if (idx >= end) acc
    else loop(idx + 1, acc * 10 + bytes(idx) - '0')

  loop(1, 0)
}
}
}
56 changes: 22 additions & 34 deletions modules/redis/src/test/scala/zio/redis/internal/RespValueSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,35 @@ package zio.redis.internal

import zio.Chunk
import zio.redis._
import zio.stream.ZStream
import zio.test.Assertion._
import zio.test._

import java.nio.charset.StandardCharsets

object RespValueSpec extends BaseSpec {
def spec: Spec[Any, RedisError.ProtocolError] =
suite("RespValue")(
suite("serialization")(
test("array") {
val expected = Chunk.fromArray("*3\r\n$3\r\nabc\r\n:123\r\n$-1\r\n".getBytes(StandardCharsets.UTF_8))
val v = RespValue.array(RespValue.bulkString("abc"), RespValue.Integer(123), RespValue.NullBulkString)
assert(v.serialize)(equalTo(expected))
}
),
suite("deserialization")(
test("array") {
val values = Chunk(
RespValue.SimpleString("OK"),
test("serializes and deserializes messages") {
val values = Chunk(
RespValue.SimpleString("OK"),
RespValue.bulkString("test1"),
RespValue.array(
RespValue.bulkString("test1"),
RespValue.array(
RespValue.bulkString("test1"),
RespValue.Integer(42L),
RespValue.NullBulkString,
RespValue.array(RespValue.SimpleString("a"), RespValue.Integer(0L)),
RespValue.bulkString("in array"),
RespValue.SimpleString("test2")
),
RespValue.NullBulkString
)
RespValue.Integer(42L),
RespValue.NullBulkString,
RespValue.array(RespValue.SimpleString("a"), RespValue.Integer(0L)),
RespValue.bulkString("in array"),
RespValue.SimpleString("test2")
),
RespValue.NullBulkString
)

zio.stream.ZStream
.fromChunk(values)
.mapConcat(_.serialize)
.via(RespValue.Decoder)
.collect { case Some(value) =>
value
}
.runCollect
.map(assert(_)(equalTo(values)))
}
)
ZStream
.fromChunk(values)
.mapConcat(_.asBytes)
.via(RespValue.Decoder)
.collectSome
.runCollect
.map(assert(_)(equalTo(values)))
}
)
}