Skip to content

Commit

Permalink
Fix #88: Implement Connection Failover (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
sitingren authored Jan 26, 2021
1 parent a13a5d3 commit ab170fe
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Currently supported query arguments are:
| tlsmode | the ssl/tls policy for this connection | 'none' (default) = don't use SSL/TLS for this connection |
| | | 'server' = server must support SSL/TLS, but skip verification **(INSECURE!)** |
| | | 'server-strict' = server must support SSL/TLS |
| backup_server_node | a list of backup hosts for the client to try to connect to if the primary host is unreachable | a comma-separated list of backup host-port pairs. E.g.<br> 'host1:port1,host2:port2,host3:port3' |

To ping the server and validate a connection (as the connection isn't necessarily created at that moment), simply call the *PingContext()* method.

Expand Down
50 changes: 47 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type connection struct {
cancelKey uint32
transactionState byte
usePreparedStmts bool
connHostsList []string
scratch [512]byte
sessionID string
serverTZOffset string
Expand Down Expand Up @@ -185,15 +186,27 @@ func newConnection(connString string) (*connection, error) {
// Read connection load balance flag.
loadBalanceFlag := result.connURL.Query().Get("connection_load_balance")

// Read connection failover flag.
backupHostsStr := result.connURL.Query().Get("backup_server_node")
if backupHostsStr == "" {
result.connHostsList = []string{result.connURL.Host}
} else {
// Parse comma-separated list of backup host-port pairs
hosts := strings.Split(backupHostsStr, ",")
// Push target host to front of the hosts list
result.connHostsList = append([]string{result.connURL.Host}, hosts...)
}

// Read SSL/TLS flag.
sslFlag := strings.ToLower(result.connURL.Query().Get("tlsmode"))
if sslFlag == "" {
sslFlag = "none"
}

result.conn, err = net.Dial("tcp", result.connURL.Host)
result.conn, err = result.establishSocketConnection()

if err != nil {
return nil, fmt.Errorf("cannot connect to %s (%s)", result.connURL.Host, err.Error())
return nil, err
}

// Load Balancing
Expand All @@ -220,6 +233,28 @@ func newConnection(connString string) (*connection, error) {
return result, nil
}

// establishSocketConnection dials each address in v.connHostsList, in order,
// and returns the first TCP connection that succeeds.
//
// On success, v.connHostsList is re-sliced so that the host which answered
// becomes its first element; hosts that failed ahead of it are dropped from
// the list.
//
// If every dial fails (or the list is empty), it returns a nil connection and
// an error whose message lists each attempted host with its dial error.
func (v *connection) establishSocketConnection() (net.Conn, error) {
	// Accumulates one "\n '<host>': <error>" entry per failed dial.
	var failures strings.Builder

	// Failover: loop to try all hosts in the list
	for i, host := range v.connHostsList {
		// net.Dial will resolve the host to multiple IP addresses,
		// and try each IP address in order until one succeeds.
		conn, err := net.Dial("tcp", host)
		if err != nil {
			fmt.Fprintf(&failures, "\n '%s': %s", host, err.Error())
			continue
		}

		// Report the hosts that were skipped before this one answered.
		if failures.Len() != 0 {
			connectionLogger.Debug("Failed to establish a connection to %s", failures.String())
		}
		connectionLogger.Debug("Established socket connection to %s", host)

		// Keep the working host at the front of the list.
		v.connHostsList = v.connHostsList[i:]
		return conn, nil
	}

	// All of the hosts failed. The message text is kept as-is so existing
	// callers that inspect it are unaffected.
	return nil, fmt.Errorf("Failed to establish a connection to the primary server or any backup host.%s", failures.String())
}

func (v *connection) recvMessage() (msgs.BackEndMsg, error) {
msgHeader := v.scratch[:5]

Expand Down Expand Up @@ -494,9 +529,18 @@ func (v *connection) balanceLoad() error {
// v.connURL.Hostname() is used by initializeSSL(), so load balancing info should not write into v.connURL
loadBalanceAddr := fmt.Sprintf("%s:%d", msg.Host, msg.Port)

if v.connHostsList[0] == loadBalanceAddr {
// Already connecting to the host
return nil
}

// Push the new host onto the host list before connecting again.
// Note that this leaves the originally-specified host as the first failover possibility
v.connHostsList = append([]string{loadBalanceAddr}, v.connHostsList...)

// Connect to new host
v.conn.Close()
v.conn, err = net.Dial("tcp", loadBalanceAddr)
v.conn, err = v.establishSocketConnection()

if err != nil {
return fmt.Errorf("cannot redirect to %s (%s)", loadBalanceAddr, err.Error())
Expand Down
21 changes: 16 additions & 5 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ import (
)

var (
testLogger = logger.New("test")
myDBConnectString string
otherConnectString string
badConnectString string
ctx context.Context
testLogger = logger.New("test")
myDBConnectString string
otherConnectString string
badConnectString string
failoverConnectString string
ctx context.Context
)

func assertTrue(t *testing.T, v bool) {
Expand Down Expand Up @@ -403,6 +404,15 @@ func TestTransaction(t *testing.T) {
assertNoErr(t, tx.Rollback())
}

// TestConnFailover opens a connection using failoverConnectString, pings it,
// and closes it, asserting that none of the three steps returns an error.
func TestConnFailover(t *testing.T) {
	// Connection string's "backup_server_node" parameter contains the correct host
	db, openErr := sql.Open("vertica", failoverConnectString)
	assertNoErr(t, openErr)

	pingErr := db.PingContext(ctx)
	assertNoErr(t, pingErr)

	closeErr := db.Close()
	assertNoErr(t, closeErr)
}

func TestPWAuthentication(t *testing.T) {
connDB := openConnection(t, "test_pw_authentication_pre")
defer closeConnection(t, connDB, "test_pw_authentication_post")
Expand Down Expand Up @@ -1094,6 +1104,7 @@ func init() {
myDBConnectString = "vertica://" + *verticaUserName + ":" + *verticaPassword + "@" + *verticaHostPort + "/" + *verticaUserName + "?" + usePreparedStmtsString + "&tlsMode=" + *tlsMode
otherConnectString = "vertica://TestGuy:TestGuyPass@" + *verticaHostPort + "/TestGuy?tlsmode=" + *tlsMode
badConnectString = "vertica://TestGuy:TestGuyBadPass@" + *verticaHostPort + "/TestGuy?tlsmode=" + *tlsMode
failoverConnectString = "vertica://" + *verticaUserName + ":" + *verticaPassword + "@badHost" + "/" + *verticaUserName + "?backup_server_node=abc.com:100000," + *verticaHostPort + ",localhost:port"

ctx = context.Background()
}
3 changes: 1 addition & 2 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import (
"fmt"
"io"
"math/rand"
"net"
"os"
"reflect"
"regexp"
Expand Down Expand Up @@ -252,7 +251,7 @@ func (s *stmt) QueryContextRaw(ctx context.Context, baseArgs []driver.NamedValue
case <-ctx.Done():
stmtLogger.Info("Context cancelled, cancelling %s", s.preparedName)
cancelMsg := msgs.FECancelMsg{PID: pid, Key: key}
conn, err := net.Dial("tcp", s.conn.connURL.Host)
conn, err := s.conn.establishSocketConnection()
if err != nil {
stmtLogger.Warn("unable to establish connection for cancellation")
return
Expand Down

0 comments on commit ab170fe

Please sign in to comment.