Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support f64 for ubjson. #10055

Merged
merged 4 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions include/xgboost/json.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2023 by XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#ifndef XGBOOST_JSON_H_
#define XGBOOST_JSON_H_
Expand Down Expand Up @@ -42,7 +42,8 @@ class Value {
kBoolean,
kNull,
// typed array for ubjson
kNumberArray,
kF32Array,
kF64Array,
kU8Array,
kI32Array,
kI64Array
Expand Down Expand Up @@ -173,7 +174,11 @@ class JsonTypedArray : public Value {
/**
* @brief Typed UBJSON array for 32-bit floating point.
*/
using F32Array = JsonTypedArray<float, Value::ValueKind::kNumberArray>;
using F32Array = JsonTypedArray<float, Value::ValueKind::kF32Array>;
/**
* @brief Typed UBJSON array for 64-bit floating point.
*/
using F64Array = JsonTypedArray<double, Value::ValueKind::kF64Array>;
/**
* @brief Typed UBJSON array for uint8_t.
*/
Expand Down Expand Up @@ -457,9 +462,9 @@ class Json {
Json& operator[](int ind) const { return (*ptr_)[ind]; }

/*! \brief Return the reference to stored Json value. */
Value const& GetValue() const & { return *ptr_; }
Value const& GetValue() && { return *ptr_; }
Value& GetValue() & { return *ptr_; }
[[nodiscard]] Value const& GetValue() const& { return *ptr_; }
Value const& GetValue() && { return *ptr_; }
Value& GetValue() & { return *ptr_; }

bool operator==(Json const& rhs) const {
return *ptr_ == *(rhs.ptr_);
Expand All @@ -472,7 +477,7 @@ class Json {
return os;
}

IntrusivePtr<Value> const& Ptr() const { return ptr_; }
[[nodiscard]] IntrusivePtr<Value> const& Ptr() const { return ptr_; }

private:
IntrusivePtr<Value> ptr_{new JsonNull};
Expand Down
4 changes: 3 additions & 1 deletion include/xgboost/json_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class JsonWriter {

virtual void Visit(JsonArray const* arr);
virtual void Visit(F32Array const* arr);
// Default handling of f64 typed arrays: this writer has no encoding for them,
// so it reports a fatal error. UBJWriter overrides this overload.
virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; }
virtual void Visit(U8Array const* arr);
virtual void Visit(I32Array const* arr);
virtual void Visit(I64Array const* arr);
Expand Down Expand Up @@ -244,7 +245,8 @@ class UBJReader : public JsonReader {
*/
class UBJWriter : public JsonWriter {
void Visit(JsonArray const* arr) override;
void Visit(F32Array const* arr) override;
void Visit(F32Array const* arr) override;
void Visit(F64Array const* arr) override;
void Visit(U8Array const* arr) override;
void Visit(I32Array const* arr) override;
void Visit(I64Array const* arr) override;
Expand Down
38 changes: 24 additions & 14 deletions src/common/json.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/**
* Copyright 2019-2023, XGBoost Contributors
* Copyright 2019-2024, XGBoost Contributors
*/
#include "xgboost/json.h"

#include <array> // for array
#include <cctype> // for isdigit
#include <cmath> // for isinf, isnan
#include <cstdint> // for uint8_t, uint16_t, uint32_t
#include <cstdio> // for EOF
#include <cstdlib> // for size_t, strtof
#include <cstring> // for memcpy
Expand Down Expand Up @@ -72,15 +73,16 @@ void JsonWriter::Visit(JsonNumber const* num) {
}

void JsonWriter::Visit(JsonInteger const* num) {
char i2s_buffer_[NumericLimits<int64_t>::kToCharsSize];
std::array<char, NumericLimits<int64_t>::kToCharsSize> i2s_buffer_;
auto i = num->GetInteger();
auto ret = to_chars(i2s_buffer_, i2s_buffer_ + NumericLimits<int64_t>::kToCharsSize, i);
auto ret =
to_chars(i2s_buffer_.data(), i2s_buffer_.data() + NumericLimits<int64_t>::kToCharsSize, i);
auto end = ret.ptr;
CHECK(ret.ec == std::errc());
auto digits = std::distance(i2s_buffer_, end);
auto digits = std::distance(i2s_buffer_.data(), end);
auto ori_size = stream_->size();
stream_->resize(ori_size + digits);
std::memcpy(stream_->data() + ori_size, i2s_buffer_, digits);
std::memcpy(stream_->data() + ori_size, i2s_buffer_.data(), digits);
}

void JsonWriter::Visit(JsonNull const* ) {
Expand Down Expand Up @@ -143,8 +145,10 @@ std::string Value::TypeStr() const {
return "Null";
case ValueKind::kInteger:
return "Integer";
case ValueKind::kNumberArray:
case ValueKind::kF32Array:
return "F32Array";
case ValueKind::kF64Array:
return "F64Array";
case ValueKind::kU8Array:
return "U8Array";
case ValueKind::kI32Array:
Expand Down Expand Up @@ -262,10 +266,11 @@ bool JsonTypedArray<T, kind>::operator==(Value const& rhs) const {
return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin());
}

template class JsonTypedArray<float, Value::ValueKind::kNumberArray>;
template class JsonTypedArray<uint8_t, Value::ValueKind::kU8Array>;
template class JsonTypedArray<int32_t, Value::ValueKind::kI32Array>;
template class JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;
template class JsonTypedArray<float, Value::ValueKind::kF32Array>;
template class JsonTypedArray<double, Value::ValueKind::kF64Array>;
template class JsonTypedArray<std::uint8_t, Value::ValueKind::kU8Array>;
template class JsonTypedArray<std::int32_t, Value::ValueKind::kI32Array>;
template class JsonTypedArray<std::int64_t, Value::ValueKind::kI64Array>;

// Json Number
bool JsonNumber::operator==(Value const& rhs) const {
Expand Down Expand Up @@ -708,6 +713,8 @@ Json UBJReader::ParseArray() {
switch (type) {
case 'd':
return ParseTypedArray<F32Array>(n);
case 'D':
return ParseTypedArray<F64Array>(n);
case 'U':
return ParseTypedArray<U8Array>(n);
case 'l':
Expand Down Expand Up @@ -797,6 +804,10 @@ Json UBJReader::Parse() {
auto v = this->ReadPrimitive<float>();
return Json{v};
}
case 'D': {
auto v = this->ReadPrimitive<double>();
return Json{v};
}
case 'S': {
auto str = this->DecodeStr();
return Json{str};
Expand Down Expand Up @@ -825,10 +836,6 @@ Json UBJReader::Parse() {
Integer::Int i = this->ReadPrimitive<char>();
return Json{i};
}
case 'D': {
LOG(FATAL) << "f64 is not supported.";
break;
}
case 'H': {
LOG(FATAL) << "High precision number is not supported.";
break;
Expand Down Expand Up @@ -882,6 +889,8 @@ void WriteTypedArray(JsonTypedArray<T, kind> const* arr, std::vector<char>* stre
stream->push_back('$');
if (std::is_same<T, float>::value) {
stream->push_back('d');
} else if (std::is_same_v<T, double>) {
stream->push_back('D');
} else if (std::is_same<T, int8_t>::value) {
stream->push_back('i');
} else if (std::is_same<T, uint8_t>::value) {
Expand Down Expand Up @@ -910,6 +919,7 @@ void WriteTypedArray(JsonTypedArray<T, kind> const* arr, std::vector<char>* stre
}

// UBJSON typed-array visitors. Every element type is serialized by the same
// WriteTypedArray helper, which receives the array and this writer's stream_
// and selects the type marker from the element type (e.g. 'd' for float,
// 'D' for double).
void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(F64Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); }
void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); }
Expand Down
35 changes: 35 additions & 0 deletions tests/cpp/common/test_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,40 @@ TEST(Json, TypedArray) {
ASSERT_EQ(arr[i + 8], i);
}
}

{
Json f64{Object{}};
auto array = F64Array();
auto& vec = array.GetArray();
// Construct test data
vec.resize(18);
std::iota(vec.begin(), vec.end(), 0.0);
// special values
vec.push_back(std::numeric_limits<double>::epsilon());
vec.push_back(std::numeric_limits<double>::max());
vec.push_back(std::numeric_limits<double>::min());
vec.push_back(std::numeric_limits<double>::denorm_min());
vec.push_back(std::numeric_limits<double>::quiet_NaN());

static_assert(
std::is_same_v<double, typename std::remove_reference_t<decltype(vec)>::value_type>);

f64["f64"] = std::move(array);
ASSERT_TRUE(IsA<F64Array>(f64["f64"]));
std::vector<char> out;
Json::Dump(f64, &out, std::ios::binary);

auto loaded = Json::Load(StringView{out.data(), out.size()}, std::ios::binary);
ASSERT_TRUE(IsA<F64Array>(loaded["f64"]));
auto const& result = get<F64Array const>(loaded["f64"]);

auto& vec1 = get<F64Array const>(f64["f64"]);
ASSERT_EQ(result.size(), vec1.size());
for (std::size_t i = 0; i < vec1.size() - 1; ++i) {
ASSERT_EQ(result[i], vec1[i]);
}
ASSERT_TRUE(std::isnan(result.back()));
}
}

TEST(UBJson, Basic) {
Expand Down Expand Up @@ -694,6 +728,7 @@ TEST(UBJson, Basic) {
}
}


TEST(Json, TypeCheck) {
Json config{Object{}};
config["foo"] = String{"bar"};
Expand Down
11 changes: 10 additions & 1 deletion tests/cpp/test_serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void CompareJSON(Json l, Json r) {
}
break;
}
case Value::ValueKind::kNumberArray: {
case Value::ValueKind::kF32Array: {
auto const& l_arr = get<F32Array const>(l);
auto const& r_arr = get<F32Array const>(r);
ASSERT_EQ(l_arr.size(), r_arr.size());
Expand All @@ -69,6 +69,15 @@ void CompareJSON(Json l, Json r) {
}
break;
}
case Value::ValueKind::kF64Array: {
auto const& l_arr = get<F64Array const>(l);
auto const& r_arr = get<F64Array const>(r);
ASSERT_EQ(l_arr.size(), r_arr.size());
for (size_t i = 0; i < l_arr.size(); ++i) {
ASSERT_NEAR(l_arr[i], r_arr[i], kRtEps);
}
break;
}
case Value::ValueKind::kU8Array: {
CompareIntArray<U8Array>(l, r);
break;
Expand Down
Loading