Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh ldap token with error (#15182) #15460

Open
wants to merge 1 commit into
base: stable-24-4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ydb/core/security/ldap_auth_provider/ldap_auth_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ class TLdapAuthProvider : public NActors::TActorBootstrapped<TLdapAuthProvider>
response.Status = NKikimrLdap::ErrorToStatus(result);
response.Error = {.Message = ERROR_MESSAGE, .LogMessage = logErrorMessage, .Retryable = NKikimrLdap::IsRetryableError(result)};
LDAP_LOG_D(logErrorMessage);
NKikimrLdap::MsgFree(searchMessage);
return response;
}
const int countEntries = NKikimrLdap::CountEntries(request.Ld, searchMessage);
Expand Down Expand Up @@ -357,6 +358,7 @@ class TLdapAuthProvider : public NActors::TActorBootstrapped<TLdapAuthProvider>
LDAPMessage* searchMessage = nullptr;
int result = NKikimrLdap::Search(ld, Settings.GetBaseDn(), NKikimrLdap::EScope::SUBTREE, filter, NKikimrLdap::noAttributes, 0, &searchMessage);
if (!NKikimrLdap::IsSuccess(result)) {
NKikimrLdap::MsgFree(searchMessage);
return {};
}
const int countEntries = NKikimrLdap::CountEntries(ld, searchMessage);
Expand Down Expand Up @@ -403,6 +405,7 @@ class TLdapAuthProvider : public NActors::TActorBootstrapped<TLdapAuthProvider>
LDAPMessage* searchMessage = nullptr;
int result = NKikimrLdap::Search(ld, Settings.GetBaseDn(), NKikimrLdap::EScope::SUBTREE, filter, RequestedAttributes, 0, &searchMessage);
if (!NKikimrLdap::IsSuccess(result)) {
NKikimrLdap::MsgFree(searchMessage);
return;
}
if (NKikimrLdap::CountEntries(ld, searchMessage) == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ NKikimr::TEvLdapAuthProvider::EStatus ErrorToStatus(int err) {
bool IsRetryableError(int error) {
switch (error) {
case LDAP_SERVER_DOWN:
case LDAP_TIMEOUT:
case LDAP_CONNECT_ERROR:
case LDAP_BUSY:
case LDAP_UNAVAILABLE:
case LDAP_ADMINLIMIT_EXCEEDED:
return true;
}
return false;
Expand Down
67 changes: 67 additions & 0 deletions ydb/core/security/ldap_auth_provider/ldap_auth_provider_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,61 @@ void CheckRequiredLdapSettings(std::function<void(NKikimrProto::TLdapAuthenticat
ldapServer.Stop();
}

void LdapRefreshGroupsInfoWithError(const ESecurityConnectionType& secureType) {
TString login = "ldapuser";
TString password = "ldapUserPassword";

TLdapKikimrServer server(InitLdapSettings, secureType);
auto responses = TCorrectLdapResponse::GetResponses(login);
LdapMock::TLdapMockResponses updatedResponses = responses;
LdapMock::TSearchResponseInfo responseServerBusy {
.ResponseEntries = {}, // Server is busy, can retry attempt
.ResponseDone = {.Status = LdapMock::EStatus::BUSY}
};

auto& searchResponse = responses.SearchResponses.front();
searchResponse.second = responseServerBusy;
LdapMock::TLdapSimpleServer ldapServer(server.GetLdapPort(), {responses, updatedResponses}, secureType == ESecurityConnectionType::LDAPS_SCHEME);

auto loginResponse = GetLoginResponse(server, login, password);
TTestActorRuntime* runtime = server.GetRuntime();
TActorId sender = runtime->AllocateEdgeActor();
runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
TAutoPtr<IEventHandle> handle;
TEvTicketParser::TEvAuthorizeTicketResult* ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);

// Server is busy, return retryable error
UNIT_ASSERT_C(!ticketParserResult->Error.empty(), "Expected return error message");
UNIT_ASSERT(ticketParserResult->Token == nullptr);
UNIT_ASSERT_STRINGS_EQUAL(ticketParserResult->Error.Message, "Could not login via LDAP");
UNIT_ASSERT_EQUAL(ticketParserResult->Error.Retryable, true);

Sleep(TDuration::Seconds(3));
ldapServer.UpdateResponses();
Sleep(TDuration::Seconds(7));

runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);
ticketParserResult = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);

// After refresh ticket, server return success
UNIT_ASSERT_C(ticketParserResult->Error.empty(), ticketParserResult->Error);
UNIT_ASSERT(ticketParserResult->Token != nullptr);
const TString ldapDomain = "@ldap";
UNIT_ASSERT_VALUES_EQUAL(ticketParserResult->Token->GetUserSID(), login + ldapDomain);
const auto& fetchedGroups = ticketParserResult->Token->GetGroupSIDs();
THashSet<TString> groups(fetchedGroups.begin(), fetchedGroups.end());

THashSet<TString> expectedGroups = TCorrectLdapResponse::GetAllGroups(ldapDomain);
expectedGroups.insert("all-users@well-known");

UNIT_ASSERT_VALUES_EQUAL(fetchedGroups.size(), expectedGroups.size());
for (const auto& expectedGroup : expectedGroups) {
UNIT_ASSERT_C(groups.contains(expectedGroup), "Can not find " + expectedGroup);
}

ldapServer.Stop();
}

Y_UNIT_TEST_SUITE(LdapAuthProviderTest) {
Y_UNIT_TEST(LdapServerIsUnavailable) {
CheckRequiredLdapSettings(InitLdapSettingsWithUnavailableHost, "Could not login via LDAP", ESecurityConnectionType::START_TLS);
Expand Down Expand Up @@ -1246,6 +1301,10 @@ Y_UNIT_TEST_SUITE(LdapAuthProviderTest_LdapsScheme) {
Y_UNIT_TEST(LdapRefreshRemoveUserBad) {
LdapRefreshRemoveUserBad(ESecurityConnectionType::LDAPS_SCHEME);
}

Y_UNIT_TEST(LdapRefreshGroupsInfoWithError) {
LdapRefreshGroupsInfoWithError(ESecurityConnectionType::LDAPS_SCHEME);
}
}

Y_UNIT_TEST_SUITE(LdapAuthProviderTest_StartTls) {
Expand Down Expand Up @@ -1304,6 +1363,10 @@ Y_UNIT_TEST_SUITE(LdapAuthProviderTest_StartTls) {
Y_UNIT_TEST(LdapRefreshRemoveUserBad) {
LdapRefreshRemoveUserBad(ESecurityConnectionType::START_TLS);
}

Y_UNIT_TEST(LdapRefreshGroupsInfoWithError) {
LdapRefreshGroupsInfoWithError(ESecurityConnectionType::START_TLS);
}
}

Y_UNIT_TEST_SUITE(LdapAuthProviderTest_nonSecure) {
Expand Down Expand Up @@ -1362,6 +1425,10 @@ Y_UNIT_TEST_SUITE(LdapAuthProviderTest_nonSecure) {
Y_UNIT_TEST(LdapRefreshRemoveUserBad) {
LdapRefreshRemoveUserBad(ESecurityConnectionType::NON_SECURE);
}

Y_UNIT_TEST(LdapRefreshGroupsInfoWithError) {
LdapRefreshGroupsInfoWithError(ESecurityConnectionType::NON_SECURE);
}
}

} // NKikimr
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ NKikimr::TEvLdapAuthProvider::EStatus ErrorToStatus(int err) {
bool IsRetryableError(int error) {
switch (error) {
case LDAP_SERVER_DOWN:
case LDAP_TIMEOUT:
case LDAP_CONNECT_ERROR:
case LDAP_BUSY:
case LDAP_UNAVAILABLE:
case LDAP_ADMIN_LIMIT_EXCEEDED:
return true;
}
return false;
Expand Down
60 changes: 34 additions & 26 deletions ydb/core/security/ticket_parser_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1741,12 +1741,16 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {
void SetError(const TString& key, TTokenRecord& record, const TEvTicketParser::TError& error) {
record.Error = error;
TInstant now = TlsActivationContext->Now();
TStringBuilder errorLogMessage;
if (error.HasLogMessage()) {
errorLogMessage << " (" << error.LogMessage << ")";
}
if (record.Error.Retryable) {
record.ExpireTime = GetExpireTime(record, now);
record.SetErrorRefreshTime(this, now);
CounterTicketsErrorsRetryable->Inc();
BLOG_D("Ticket " << record.GetMaskedTicket() << " ("
<< record.PeerName << ") has now retryable error message '" << error.Message << "'");
<< record.PeerName << ") has now retryable error message '" << error.Message << errorLogMessage << "'");
if (record.RefreshRetryableErrorImmediately) {
record.RefreshRetryableErrorImmediately = false;
GetDerived()->CanRefreshTicket(key, record);
Expand All @@ -1759,7 +1763,7 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {
record.SetOkRefreshTime(this, now);
CounterTicketsErrorsPermanent->Inc();
BLOG_D("Ticket " << record.GetMaskedTicket() << " ("
<< record.PeerName << ") has now permanent error message '" << error.Message << "'");
<< record.PeerName << ") has now permanent error message '" << error.Message << errorLogMessage << "'");
}
CounterTicketsErrors->Inc();
record.IsLowAccessServiceRequestPriority = true;
Expand Down Expand Up @@ -1841,7 +1845,7 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {

template <typename TTokenRecord>
bool CanRefreshLoginTicket(const TTokenRecord& record) {
return record.TokenType == TDerived::ETokenType::Login && record.Error.empty();
return record.TokenType == TDerived::ETokenType::Login && record.Error.Retryable;
}

template <typename TTokenRecord>
Expand All @@ -1862,30 +1866,34 @@ class TTicketParserImpl : public TActorBootstrapped<TDerived> {

template <typename TTokenRecord>
bool RefreshLoginTicket(const TString& key, TTokenRecord& record) {
GetDerived()->ResetTokenRecord(record);
const TString userSID = record.GetToken()->GetUserSID();
if (record.IsExternalAuthEnabled()) {
return RefreshTicketViaExternalAuthProvider(key, record);
}
const TString& database = Config.GetDomainLoginOnly() ? DomainName : record.Database;
auto itLoginProvider = LoginProviders.find(database);
if (itLoginProvider == LoginProviders.end()) {
return false;
}
NLogin::TLoginProvider& loginProvider(itLoginProvider->second);
if (loginProvider.CheckUserExists(userSID)) {
const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(userSID);
const TVector<NACLib::TSID> groups(providerGroups.begin(), providerGroups.end());
SetToken(key, record, new NACLib::TUserToken({
.OriginalUserToken = record.Ticket,
.UserSID = userSID,
.GroupSIDs = groups,
.AuthType = record.GetAuthType()
}));
} else {
SetError(key, record, {.Message = "User not found", .Retryable = false});
if (record.Error.empty()) {
GetDerived()->ResetTokenRecord(record);
const TString userSID = record.GetToken()->GetUserSID();
if (record.IsExternalAuthEnabled()) {
return RefreshTicketViaExternalAuthProvider(key, record);
}
const TString& database = Config.GetDomainLoginOnly() ? DomainName : record.Database;
auto itLoginProvider = LoginProviders.find(database);
if (itLoginProvider == LoginProviders.end()) {
return false;
}
NLogin::TLoginProvider& loginProvider(itLoginProvider->second);
if (loginProvider.CheckUserExists(userSID)) {
const std::vector<TString> providerGroups = loginProvider.GetGroupsMembership(userSID);
const TVector<NACLib::TSID> groups(providerGroups.begin(), providerGroups.end());
SetToken(key, record, new NACLib::TUserToken({
.OriginalUserToken = record.Ticket,
.UserSID = userSID,
.GroupSIDs = groups,
.AuthType = record.GetAuthType()
}));
} else {
SetError(key, record, {.Message = "User not found", .Retryable = false});
}
return true;
}
return true;
GetDerived()->ResetTokenRecord(record);
return CanInitLoginToken(key, record);
}

template <typename TTokenRecord>
Expand Down
66 changes: 66 additions & 0 deletions ydb/core/security/ticket_parser_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,72 @@ Y_UNIT_TEST_SUITE(TTicketParserTest) {
UNIT_ASSERT_VALUES_EQUAL(result->Token->GetGroupSIDs().size(), 3);
}

Y_UNIT_TEST(LoginRefreshGroupsWithError) {
using namespace Tests;
TPortManager tp;
ui16 kikimrPort = tp.GetPort(2134);
ui16 grpcPort = tp.GetPort(2135);
NKikimrProto::TAuthConfig authConfig;
authConfig.SetUseBlackBox(false);
authConfig.SetUseLoginProvider(true);
authConfig.SetRefreshTime("5s");
auto settings = TServerSettings(kikimrPort, authConfig);
settings.SetDomainName("Root");
settings.CreateTicketParser = NKikimr::CreateTicketParser;
TServer server(settings);
server.EnableGRpc(grpcPort);
server.GetRuntime()->SetLogPriority(NKikimrServices::TICKET_PARSER, NLog::PRI_TRACE);
server.GetRuntime()->SetLogPriority(NKikimrServices::GRPC_CLIENT, NLog::PRI_TRACE);
TClient client(settings);
NClient::TKikimr kikimr(client.GetClientConfig());
client.InitRootScheme();
TTestActorRuntime* runtime = server.GetRuntime();

NLogin::TLoginProvider provider;

provider.Audience = "/Root";
provider.RotateKeys();

TActorId sender = runtime->AllocateEdgeActor();

provider.CreateGroup({.Group = "group1"});
provider.CreateUser({.User = "user1", .Password = "password1"});
provider.AddGroupMembership({.Group = "group1", .Member = "user1"});

NLogin::TLoginProvider emptyProvider;
emptyProvider.Audience = "/Root";

runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvUpdateLoginSecurityState(emptyProvider.GetSecurityState())), 0);

auto loginResponse = provider.LoginUser({.User = "user1", .Password = "password1"});

UNIT_ASSERT_VALUES_EQUAL(loginResponse.Error, "");

runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);

TAutoPtr<IEventHandle> handle;

TEvTicketParser::TEvAuthorizeTicketResult* result = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);
UNIT_ASSERT_C(!result->Error.empty(), "Expected return error message");
UNIT_ASSERT(result->Token == nullptr);
UNIT_ASSERT_STRINGS_EQUAL(result->Error.Message, "Security state is empty");
UNIT_ASSERT_EQUAL(result->Error.Retryable, true);

Sleep(TDuration::Seconds(3));
runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvUpdateLoginSecurityState(provider.GetSecurityState())), 0);
Sleep(TDuration::Seconds(7));

runtime->Send(new IEventHandle(MakeTicketParserID(), sender, new TEvTicketParser::TEvAuthorizeTicket(loginResponse.Token)), 0);

result = runtime->GrabEdgeEvent<TEvTicketParser::TEvAuthorizeTicketResult>(handle);

UNIT_ASSERT_C(result->Error.empty(), result->Error);
UNIT_ASSERT(result->Token != nullptr);
UNIT_ASSERT_VALUES_EQUAL(result->Token->GetUserSID(), "user1");
UNIT_ASSERT(result->Token->IsExist("group1"));
UNIT_ASSERT_VALUES_EQUAL(result->Token->GetGroupSIDs().size(), 2);
}

Y_UNIT_TEST(LoginCheckRemovedUser) {
using namespace Tests;
TPortManager tp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum EStatus {
SUCCESS = 0x00,
PROTOCOL_ERROR = 0x02,
INVALID_CREDENTIALS = 0x31,
BUSY = 0x33,
};

enum EProtocolOp {
Expand Down
Loading