Skip to content

Commit

Permalink
[KIKIMR-22131] Handle potential race in computation pattern cache (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
abyss7 authored Feb 25, 2025
1 parent 09bfb76 commit abfc4e2
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 112 deletions.
24 changes: 15 additions & 9 deletions ydb/library/yql/dq/runtime/dq_tasks_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,17 +442,23 @@ class TDqTaskRunner : public IDqTaskRunner {
bool canBeCached;
if (UseSeparatePatternAlloc(task) && Context.PatternCache) {
auto& cache = Context.PatternCache;
auto ticket = cache->FindOrSubscribe(program.GetRaw());
if (!ticket.HasFuture()) {
entry = CreateComputationPattern(task, program.GetRaw(), true, canBeCached);
if (canBeCached && entry->Pattern->GetSuitableForCache()) {
cache->EmplacePattern(task.GetProgram().GetRaw(), entry);
ticket.Close();
} else {
cache->IncNotSuitablePattern();
auto future = cache->FindOrSubscribe(program.GetRaw());
if (!future.HasValue()) {
try {
entry = CreateComputationPattern(task, program.GetRaw(), true, canBeCached);
if (canBeCached && entry->Pattern->GetSuitableForCache()) {
cache->EmplacePattern(task.GetProgram().GetRaw(), entry);
} else {
cache->IncNotSuitablePattern();
cache->NotifyPatternMissing(program.GetRaw());
}
} catch (...) {
// TODO: not sure if there may be exceptions in the first place.
cache->NotifyPatternMissing(program.GetRaw());
throw;
}
} else {
entry = ticket.GetValueSync();
entry = future.GetValueSync();
}
}

Expand Down
103 changes: 51 additions & 52 deletions yql/essentials/minikql/computation/mkql_computation_pattern_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,29 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
return CurrentPatternsCompiledCodeSizeInBytes;
}

std::shared_ptr<TPatternCacheEntry>* Find(const TString& serializedProgram) {
TPatternCacheEntryPtr Find(const TString& serializedProgram) {
auto it = SerializedProgramToPatternCacheHolder.find(serializedProgram);
if (it == SerializedProgramToPatternCacheHolder.end()) {
return nullptr;
return {};
}

PromoteEntry(&it->second);

return &it->second.Entry;
return it->second.Entry;
}

void Insert(const TString& serializedProgram, std::shared_ptr<TPatternCacheEntry>& entry) {
void Insert(const TString& serializedProgram, TPatternCacheEntryPtr& entry) {
auto [it, inserted] = SerializedProgramToPatternCacheHolder.emplace(std::piecewise_construct,
std::forward_as_tuple(serializedProgram),
std::forward_as_tuple(serializedProgram, entry));

if (!inserted) {
RemoveEntryFromLists(&it->second);
entry = it->second.Entry;
} else {
entry->UpdateSizeForCache();
}

entry->UpdateSizeForCache();

/// New item is inserted, insert it in the back of both LRU lists and recalculate sizes
CurrentPatternsSizeBytes += entry->SizeForCache;
LRUPatternList.PushBack(&it->second);
Expand All @@ -69,17 +70,18 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
ClearIfNeeded();
}

void NotifyPatternCompiled(const TString & serializedProgram) {
void NotifyPatternCompiled(const TString& serializedProgram) {
auto it = SerializedProgramToPatternCacheHolder.find(serializedProgram);
if (it == SerializedProgramToPatternCacheHolder.end()) {
return;
}

const auto& entry = it->second.Entry;

// TODO(ilezhankin): wait until migration of yql to arcadia is complete and merge the proper fix from here:
// https://github.com/ydb-platform/ydb/pull/11129
if (!entry->Pattern->IsCompiled()) {
// This is possible if the old entry got removed from cache while being compiled - and the new entry got in.
// TODO: add metrics for this inefficient cache usage.
// TODO: make this scenario more consistent - don't waste compilation result.
return;
}

Expand Down Expand Up @@ -117,7 +119,7 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
* Most recently accessed items are in back of the lists, least recently accessed items are in front of the lists.
*/
struct TPatternCacheHolder : public TIntrusiveListItem<TPatternCacheHolder, TPatternLRUListTag>, TIntrusiveListItem<TPatternCacheHolder, TCompiledPatternLRUListTag> {
TPatternCacheHolder(TString serializedProgram, std::shared_ptr<TPatternCacheEntry> entry)
TPatternCacheHolder(TString serializedProgram, TPatternCacheEntryPtr entry)
: SerializedProgram(std::move(serializedProgram))
, Entry(std::move(entry))
{}
Expand All @@ -130,8 +132,8 @@ class TComputationPatternLRUCache::TLRUPatternCacheImpl
return !TIntrusiveListItem<TPatternCacheHolder, TCompiledPatternLRUListTag>::Empty();
}

TString SerializedProgram;
std::shared_ptr<TPatternCacheEntry> Entry;
const TString SerializedProgram;
TPatternCacheEntryPtr Entry;
};

void PromoteEntry(TPatternCacheHolder* holder) {
Expand Down Expand Up @@ -232,52 +234,51 @@ TComputationPatternLRUCache::~TComputationPatternLRUCache() {
CleanCache();
}

std::shared_ptr<TPatternCacheEntry> TComputationPatternLRUCache::Find(const TString& serializedProgram) {
TPatternCacheEntryPtr TComputationPatternLRUCache::Find(const TString& serializedProgram) {
std::lock_guard<std::mutex> lock(Mutex);
if (auto it = Cache->Find(serializedProgram)) {
++*Hits;

if ((*it)->Pattern->IsCompiled())
if (it->Pattern->IsCompiled())
++*HitsCompiled;

return *it;
return it;
}

++*Misses;
return {};
}

TComputationPatternLRUCache::TTicket TComputationPatternLRUCache::FindOrSubscribe(const TString& serializedProgram) {
TPatternCacheEntryFuture TComputationPatternLRUCache::FindOrSubscribe(const TString& serializedProgram) {
std::lock_guard lock(Mutex);
if (auto it = Cache->Find(serializedProgram)) {
++*Hits;
AccessPattern(serializedProgram, *it);
return TTicket(serializedProgram, false, NThreading::MakeFuture<std::shared_ptr<TPatternCacheEntry>>(*it), nullptr);
AccessPattern(serializedProgram, it);
return NThreading::MakeFuture<TPatternCacheEntryPtr>(it);
}

auto [notifyIt, isNew] = Notify.emplace(serializedProgram, Nothing());
auto [notifyIt, isNew] = Notify.emplace(std::piecewise_construct, std::forward_as_tuple(serializedProgram), std::forward_as_tuple());
if (isNew) {
++*Misses;
return TTicket(serializedProgram, true, {}, this);
// First future is empty - so the subscriber can initiate the entry creation.
return {};
}

++*Waits;
auto promise = NThreading::NewPromise<std::shared_ptr<TPatternCacheEntry>>();
auto promise = NThreading::NewPromise<TPatternCacheEntryPtr>();
auto& subscribers = notifyIt->second;
if (!subscribers) {
subscribers.ConstructInPlace();
}
subscribers.push_back(promise);

subscribers->push_back(promise);
return TTicket(serializedProgram, false, promise, nullptr);
// Second and next futures are not empty - so subscribers can wait while first one creates the entry.
return promise;
}

void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgram, std::shared_ptr<TPatternCacheEntry> patternWithEnv) {
void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgram, TPatternCacheEntryPtr& patternWithEnv) {
Y_DEBUG_ABORT_UNLESS(patternWithEnv && patternWithEnv->Pattern);
TMaybe<TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>>> subscribers;
TVector<NThreading::TPromise<TPatternCacheEntryPtr>> subscribers;

{
std::lock_guard<std::mutex> lock(Mutex);
std::lock_guard lock(Mutex);
Cache->Insert(serializedProgram, patternWithEnv);

auto notifyIt = Notify.find(serializedProgram);
Expand All @@ -292,10 +293,8 @@ void TComputationPatternLRUCache::EmplacePattern(const TString& serializedProgra
*SizeCompiledBytes = Cache->PatternsCompiledCodeSizeInBytes();
}

if (subscribers) {
for (auto& subscriber : *subscribers) {
subscriber.SetValue(patternWithEnv);
}
for (auto& subscriber : subscribers) {
subscriber.SetValue(patternWithEnv);
}
}

Expand All @@ -304,6 +303,24 @@ void TComputationPatternLRUCache::NotifyPatternCompiled(const TString& serialize
Cache->NotifyPatternCompiled(serializedProgram);
}

void TComputationPatternLRUCache::NotifyPatternMissing(const TString& serializedProgram) {
TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>> subscribers;
{
std::lock_guard lock(Mutex);

auto notifyIt = Notify.find(serializedProgram);
if (notifyIt != Notify.end()) {
subscribers.swap(notifyIt->second);
Notify.erase(notifyIt);
}
}

for (auto& subscriber : subscribers) {
// It's part of API - to set nullptr as broken promise.
subscriber.SetValue(nullptr);
}
}

size_t TComputationPatternLRUCache::GetSize() const {
std::lock_guard lock(Mutex);
return Cache->PatternsSize();
Expand All @@ -318,7 +335,7 @@ void TComputationPatternLRUCache::CleanCache() {
Cache->Clear();
}

void TComputationPatternLRUCache::AccessPattern(const TString & serializedProgram, std::shared_ptr<TPatternCacheEntry> & entry) {
void TComputationPatternLRUCache::AccessPattern(const TString& serializedProgram, TPatternCacheEntryPtr entry) {
if (!Configuration.PatternAccessTimesBeforeTryToCompile || entry->Pattern->IsCompiled()) {
return;
}
Expand All @@ -330,22 +347,4 @@ void TComputationPatternLRUCache::AccessPattern(const TString & serializedProgra
}
}

void TComputationPatternLRUCache::NotifyMissing(const TString& serialized) {
TMaybe<TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>>> subscribers;
{
std::lock_guard<std::mutex> lock(Mutex);
auto notifyIt = Notify.find(serialized);
if (notifyIt != Notify.end()) {
subscribers.swap(notifyIt->second);
Notify.erase(notifyIt);
}
}

if (subscribers) {
for (auto& subscriber : *subscribers) {
subscriber.SetValue(nullptr);
}
}
}

} // namespace NKikimr::NMiniKQL
65 changes: 14 additions & 51 deletions yql/essentials/minikql/computation/mkql_computation_pattern_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,43 +54,11 @@ struct TPatternCacheEntry {
}
};

using TPatternCacheEntryPtr = std::shared_ptr<TPatternCacheEntry>;
using TPatternCacheEntryFuture = NThreading::TFuture<TPatternCacheEntryPtr>;

class TComputationPatternLRUCache {
public:
class TTicket : private TNonCopyable {
public:
TTicket(const TString& serialized, bool isOwned, const NThreading::TFuture<std::shared_ptr<TPatternCacheEntry>>& future, TComputationPatternLRUCache* cache)
: Serialized(serialized)
, IsOwned(isOwned)
, Future(future)
, Cache(cache)
{}

~TTicket() {
if (Cache) {
Cache->NotifyMissing(Serialized);
}
}

bool HasFuture() const {
return !IsOwned;
}

std::shared_ptr<TPatternCacheEntry> GetValueSync() const {
Y_ABORT_UNLESS(HasFuture());
return Future.GetValueSync();
}

void Close() {
Cache = nullptr;
}

private:
const TString Serialized;
const bool IsOwned;
const NThreading::TFuture<std::shared_ptr<TPatternCacheEntry>> Future;
TComputationPatternLRUCache* Cache;
};

struct Config {
Config(size_t maxSizeBytes, size_t maxCompiledSizeBytes)
: MaxSizeBytes(maxSizeBytes)
Expand Down Expand Up @@ -121,17 +89,17 @@ class TComputationPatternLRUCache {

~TComputationPatternLRUCache();

static std::shared_ptr<TPatternCacheEntry> CreateCacheEntry(bool useAlloc = true) {
static TPatternCacheEntryPtr CreateCacheEntry(bool useAlloc = true) {
return std::make_shared<TPatternCacheEntry>(useAlloc);
}

std::shared_ptr<TPatternCacheEntry> Find(const TString& serializedProgram);
TPatternCacheEntryPtr Find(const TString& serializedProgram);
TPatternCacheEntryFuture FindOrSubscribe(const TString& serializedProgram);

TTicket FindOrSubscribe(const TString& serializedProgram);

void EmplacePattern(const TString& serializedProgram, std::shared_ptr<TPatternCacheEntry> patternWithEnv);
void EmplacePattern(const TString& serializedProgram, TPatternCacheEntryPtr& patternWithEnv);

void NotifyPatternCompiled(const TString& serializedProgram);
void NotifyPatternMissing(const TString& serializedProgram);

size_t GetSize() const;

Expand Down Expand Up @@ -160,27 +128,22 @@ class TComputationPatternLRUCache {
return PatternsToCompile.size();
}

void GetPatternsToCompile(THashMap<TString, std::shared_ptr<TPatternCacheEntry>> & result) {
void GetPatternsToCompile(THashMap<TString, TPatternCacheEntryPtr> & result) {
std::lock_guard lock(Mutex);
result.swap(PatternsToCompile);
}

private:
void AccessPattern(const TString & serializedProgram, std::shared_ptr<TPatternCacheEntry> & entry);

void NotifyMissing(const TString& serialized);
class TLRUPatternCacheImpl;

static constexpr size_t CacheMaxElementsSize = 10000;

friend class TTicket;
void AccessPattern(const TString& serializedProgram, TPatternCacheEntryPtr entry);

mutable std::mutex Mutex;
THashMap<TString, TMaybe<TVector<NThreading::TPromise<std::shared_ptr<TPatternCacheEntry>>>>> Notify;

class TLRUPatternCacheImpl;
std::unique_ptr<TLRUPatternCacheImpl> Cache;

THashMap<TString, std::shared_ptr<TPatternCacheEntry>> PatternsToCompile;
THashMap<TString, TVector<NThreading::TPromise<TPatternCacheEntryPtr>>> Notify; // protected by Mutex
std::unique_ptr<TLRUPatternCacheImpl> Cache; // protected by Mutex
THashMap<TString, TPatternCacheEntryPtr> PatternsToCompile; // protected by Mutex

const Config Configuration;

Expand Down

0 comments on commit abfc4e2

Please sign in to comment.