package ydbgoquery
import (
"context"
"errors"
"sync"
"github.com/adwski/ydb-go-query/internal/discovery"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/query"
"github.com/adwski/ydb-go-query/internal/transport"
balancing "github.com/adwski/ydb-go-query/internal/transport/balancing/v4"
"github.com/adwski/ydb-go-query/internal/transport/dispatcher"
qq "github.com/adwski/ydb-go-query/query"
)
var (
ErrNoInitialNodes = errors.New("no initial nodes were provided")
ErrDBEmpty = errors.New("db is empty")
ErrDiscoveryTransportCreate = errors.New("discovery transport create error")
)
func Open(ctx context.Context, cfg Config, opts ...Option) (*Client, error) {
client, err := newClient(ctx, &cfg, opts...)
if err != nil {
return nil, err
}
var runCtx context.Context
runCtx, client.cancel = context.WithCancel(ctx)
client.querySvc = query.NewService(runCtx, query.Config{
Logger: client.logger,
Transport: client.dispatcher.Transport(),
CreateTimeout: cfg.sessionCreateTimeout,
PoolSize: cfg.poolSize,
PoolReadyThresholdHigh: cfg.poolReadyHi,
PoolReadyThresholdLow: cfg.poolReadyLo,
})
client.queryCtx = qq.NewCtx(client.logger, client.querySvc, cfg.txSettings, cfg.queryTimeout)
client.wg.Add(1)
go client.dispatcher.Run(runCtx, client.wg)
client.wg.Add(1)
go client.discoverySvc.Run(runCtx, client.wg)
if cfg.auth != nil {
client.wg.Add(1)
go cfg.auth.Run(runCtx, client.wg)
}
return client, nil
}
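// Usage sketch (not part of the library API surface): opening a client
// against an assumed YDB instance at 127.0.0.1:2136 with database "/local";
// both values are illustrative.
//
//	client, err := ydbgoquery.Open(ctx, ydbgoquery.Config{
//		InitialNodes: []string{"127.0.0.1:2136"},
//		DB:           "/local",
//	})
//	if err != nil {
//		return err
//	}
//	defer client.Close()
//	qc := client.QueryCtx() // run queries through qc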
type (
authRunner interface {
transport.Authenticator
Run(ctx context.Context, wg *sync.WaitGroup)
}
Client struct {
dispatcher *dispatcher.Dynamic
discoverySvc *discovery.Service
querySvc *query.Service
queryCtx *qq.Ctx
wg *sync.WaitGroup
cancel context.CancelFunc
logger logger.Logger
}
)
func (c *Client) QueryCtx() *qq.Ctx {
return c.queryCtx
}
func (c *Client) Close() {
c.cancel()
_ = c.querySvc.Close()
c.wg.Wait()
}
func (c *Client) Ready() bool {
return c.querySvc.Ready()
}
func newClient(ctx context.Context, cfg *Config, opts ...Option) (*Client, error) {
if len(cfg.InitialNodes) == 0 {
return nil, ErrNoInitialNodes
}
if len(cfg.DB) == 0 {
return nil, ErrDBEmpty
}
cfg.setDefaults()
for _, opt := range opts {
if err := opt(ctx, cfg); err != nil {
return nil, err
}
}
tr, err := dispatcher.NewStatic(ctx, cfg.InitialNodes, cfg.transportCredentials, cfg.auth, cfg.DB)
if err != nil {
return nil, errors.Join(ErrDiscoveryTransportCreate, err)
}
discoverySvc := discovery.NewService(discovery.Config{
Logger: cfg.logger,
DB: cfg.DB,
Transport: tr.Transport(),
DoAnnounce: true,
})
dispatcherCfg := dispatcher.Config{
Logger: cfg.logger,
InitNodes: cfg.InitialNodes,
DB: cfg.DB,
Balancing: balancing.Config{
LocationPreference: cfg.locationPreference,
ConnsPerEndpoint: cfg.connectionsPerEndpoint,
IgnoreLocations: false,
},
TransportCredentials: cfg.transportCredentials,
Auth: cfg.auth,
EndpointsProvider: discoverySvc,
}
c := &Client{
logger: cfg.logger,
dispatcher: dispatcher.NewDynamic(dispatcherCfg),
discoverySvc: discoverySvc,
wg: &sync.WaitGroup{},
}
return c, nil
}
package ydbgoquery
import (
"context"
"errors"
"strings"
"time"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/logger/noop"
zaplogger "github.com/adwski/ydb-go-query/internal/logger/zap"
zerologger "github.com/adwski/ydb-go-query/internal/logger/zerolog"
"github.com/adwski/ydb-go-query/internal/query/txsettings"
"github.com/adwski/ydb-go-query/internal/transport/auth"
"github.com/adwski/ydb-go-query/internal/transport/auth/userpass"
"github.com/adwski/ydb-go-query/internal/transport/auth/yc"
transportCreds "github.com/adwski/ydb-go-query/internal/transport/credentials"
"github.com/adwski/ydb-go-query/internal/transport/dispatcher"
"github.com/rs/zerolog"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
"go.uber.org/zap"
"google.golang.org/grpc/credentials"
)
const (
defaultSessionCreateTimeout = 3 * time.Second
defaultQueryTimeout = 5 * time.Minute
defaultSessionPoolSize = 10
defaultConnectionsPerEndpoint = 2
)
var (
ErrAuthentication = errors.New("authentication failed")
)
type (
Config struct {
logger logger.Logger
transportCredentials credentials.TransportCredentials
auth authRunner
txSettings *Ydb_Query.TransactionSettings
DB string
InitialNodes []string
locationPreference []string
poolSize uint
poolReadyHi uint
poolReadyLo uint
connectionsPerEndpoint int
sessionCreateTimeout time.Duration
queryTimeout time.Duration
}
Option func(context.Context, *Config) error
)
func (cfg *Config) setDefaults() {
cfg.logger = logger.New(noop.NewLogger())
cfg.sessionCreateTimeout = defaultSessionCreateTimeout
cfg.queryTimeout = defaultQueryTimeout
cfg.poolSize = defaultSessionPoolSize
cfg.connectionsPerEndpoint = defaultConnectionsPerEndpoint
cfg.transportCredentials = transportCreds.Insecure()
cfg.txSettings = txsettings.SerializableReadWrite()
}
func WithLogger(log logger.Logger) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.logger = log
return nil
}
}
func WithZeroLogger(log zerolog.Logger, level string) Option {
return func(ctx context.Context, cfg *Config) error {
lg, err := logger.NewWithLevel(zerologger.NewLogger(log), level)
if err != nil {
return err //nolint:wrapcheck // unnecessary
}
cfg.logger = lg
return nil
}
}
func WithZapLogger(log *zap.Logger, level string) Option {
return func(ctx context.Context, cfg *Config) error {
lg, err := logger.NewWithLevel(zaplogger.NewLogger(log), level)
if err != nil {
return err //nolint:wrapcheck // unnecessary
}
cfg.logger = lg
return nil
}
}
func WithSessionCreateTimeout(timeout time.Duration) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.sessionCreateTimeout = timeout
return nil
}
}
func WithQueryTimeout(timeout time.Duration) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.queryTimeout = timeout
return nil
}
}
func WithSessionPoolSize(size uint) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.poolSize = size
return nil
}
}
func WithSessionPoolReadyThresholds(high, low uint) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.poolReadyHi = high
cfg.poolReadyLo = low
return nil
}
}
func WithLocationPreference(pref string) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.locationPreference = strings.Split(pref, ",")
for idx := range cfg.locationPreference {
cfg.locationPreference[idx] = strings.TrimSpace(cfg.locationPreference[idx])
}
return nil
}
}
func WithConnectionsPerEndpoint(connections int) Option {
return func(ctx context.Context, cfg *Config) error {
if connections > 0 {
cfg.connectionsPerEndpoint = connections
}
return nil
}
}
func withTransportSecurity(credentials credentials.TransportCredentials) Option {
return func(ctx context.Context, cfg *Config) error {
cfg.transportCredentials = credentials
return nil
}
}
func WithTransportTLS() Option {
return withTransportSecurity(transportCreds.TLS())
}
func WithYCAuthFile(filename string) Option {
return withYC(yc.Config{
IamKeyFile: filename,
})
}
func WithYCAuthBytes(iamKeyBytes []byte) Option {
return withYC(yc.Config{
IamKey: iamKeyBytes,
})
}
func withYC(ycCfg yc.Config) Option {
return func(ctx context.Context, cfg *Config) error {
ycAuth, err := yc.New(ctx, ycCfg)
if err != nil {
return err //nolint:wrapcheck // unnecessary
}
cfg.auth, err = auth.New(ctx, auth.Config{
Logger: cfg.logger,
Provider: ycAuth,
})
if err != nil {
return errors.Join(ErrAuthentication, err)
}
return nil
}
}
var ErrAuthTransport = errors.New("unable to create auth transport")
func WithUserPass(username, password string) Option {
return func(ctx context.Context, cfg *Config) error {
tr, err := dispatcher.NewStatic(ctx, cfg.InitialNodes, cfg.transportCredentials, nil, cfg.DB)
if err != nil {
return errors.Join(ErrAuthTransport, err)
}
cfg.auth, err = auth.New(ctx, auth.Config{
Logger: cfg.logger,
Provider: userpass.New(userpass.Config{
Transport: tr.Transport(),
Username: username,
Password: password,
}),
})
if err != nil {
return errors.Join(ErrAuthentication, err)
}
return nil
}
}
func WithSerializableReadWrite() Option {
return func(ctx context.Context, cfg *Config) error {
cfg.txSettings = txsettings.SerializableReadWrite()
return nil
}
}
func WithOnlineReadOnly() Option {
return func(ctx context.Context, cfg *Config) error {
cfg.txSettings = txsettings.OnlineReadOnly()
return nil
}
}
func WithOnlineReadOnlyInconsistent() Option {
return func(ctx context.Context, cfg *Config) error {
cfg.txSettings = txsettings.OnlineReadOnlyInconsistent()
return nil
}
}
func WithStaleReadOnly() Option {
return func(ctx context.Context, cfg *Config) error {
cfg.txSettings = txsettings.StaleReadOnly()
return nil
}
}
func WithSnapshotReadOnly() Option {
return func(ctx context.Context, cfg *Config) error {
cfg.txSettings = txsettings.SnapshotReadOnly()
return nil
}
}
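// Usage sketch: composing options at Open time, with cfg as in the Open
// example above; the credentials and sizes below are illustrative, not defaults.
//
//	client, err := ydbgoquery.Open(ctx, cfg,
//		ydbgoquery.WithTransportTLS(),
//		ydbgoquery.WithUserPass("user", "password"),
//		ydbgoquery.WithSessionPoolSize(50),
//		ydbgoquery.WithStaleReadOnly(),
//	)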
package discovery
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/adwski/ydb-go-query/internal/endpoints"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Discovery_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
"google.golang.org/grpc"
)
var (
ErrEndpointsList = errors.New("unable to get endpoints")
ErrEndpointsUnmarshal = errors.New("unable to unmarshal endpoints")
ErrOperationUnsuccessful = errors.New("operation unsuccessful")
)
const (
discoveryTimeout = 3 * time.Second
discoveryInterval = 30 * time.Second
discoveryErrRetry = 2 * time.Second
)
type (
Service struct {
logger logger.Logger
dsc Ydb_Discovery_V1.DiscoveryServiceClient
ann chan endpoints.Announce
filter *endpoints.Filter
epDB endpoints.DB
dbName string
}
Config struct {
Logger logger.Logger
Transport grpc.ClientConnInterface
DB string
DoAnnounce bool
}
)
func NewService(cfg Config) *Service {
svc := &Service{
dbName: cfg.DB,
logger: cfg.Logger,
filter: endpoints.NewFilter().WithQueryService(),
dsc: Ydb_Discovery_V1.NewDiscoveryServiceClient(cfg.Transport),
epDB: endpoints.NewDB(),
}
if cfg.DoAnnounce {
svc.ann = make(chan endpoints.Announce)
}
return svc
}
func (svc *Service) EndpointsChan() <-chan endpoints.Announce {
return svc.ann
}
func (svc *Service) GetAllEndpoints() endpoints.Map {
return svc.epDB.GetAll()
}
func (svc *Service) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
waitTimer := svc.endpointsTick(ctx, nil)
defer waitTimer.Stop()
runLoop:
for {
select {
case <-ctx.Done():
break runLoop
case <-waitTimer.C:
svc.endpointsTick(ctx, waitTimer)
}
}
}
func (svc *Service) endpointsTick(ctx context.Context, waitTimer *time.Timer) *time.Timer {
ctxEp, cancelEp := context.WithDeadline(ctx, time.Now().Add(discoveryTimeout))
defer cancelEp()
timerInterval := discoveryInterval
if eps, err := svc.getEndpoints(ctxEp); err != nil {
svc.logger.Error("getEndpoints failed", "error", err, "db", svc.dbName)
timerInterval = discoveryErrRetry
} else {
svc.logger.Debug("getEndpoints succeeded", "count", len(eps))
svc.updateAndAnnounce(ctx, eps)
}
if waitTimer == nil {
return time.NewTimer(timerInterval)
}
waitTimer.Reset(timerInterval)
return nil
}
func (svc *Service) updateAndAnnounce(ctx context.Context, endpoints []*Ydb_Discovery.EndpointInfo) {
if svc.epDB.Compare(endpoints) {
// endpoints did not change
return
}
announce, oldLen, newLen := svc.epDB.Update(endpoints)
svc.logger.Info("endpoints changed",
"was", oldLen,
"now", newLen,
"new", len(announce.Add),
"old", len(announce.Del))
if svc.ann == nil {
return
}
select {
case <-ctx.Done():
case svc.ann <- announce:
}
}
func (svc *Service) getEndpoints(ctx context.Context) ([]*Ydb_Discovery.EndpointInfo, error) {
resp, err := svc.dsc.ListEndpoints(ctx, &Ydb_Discovery.ListEndpointsRequest{
Database: svc.dbName,
})
if err != nil {
return nil, errors.Join(ErrEndpointsList, err)
}
status := resp.GetOperation().GetStatus()
if status != Ydb.StatusIds_SUCCESS {
return nil, errors.Join(ErrOperationUnsuccessful,
fmt.Errorf("%s", resp.GetOperation().String()))
}
var epRes Ydb_Discovery.ListEndpointsResult
if err = resp.GetOperation().GetResult().UnmarshalTo(&epRes); err != nil {
return nil, errors.Join(ErrEndpointsUnmarshal, err)
}
preferred, requiredButNotPreferred := svc.filter.Filter(epRes.Endpoints)
if len(preferred) == 0 {
return requiredButNotPreferred, nil
}
return preferred, nil
}
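// Usage sketch: running discovery and consuming announcements; cc stands
// for an assumed grpc.ClientConnInterface.
//
//	svc := NewService(Config{Logger: log, Transport: cc, DB: "/local", DoAnnounce: true})
//	wg.Add(1)
//	go svc.Run(ctx, wg)
//	go func() {
//		for ann := range svc.EndpointsChan() {
//			_, _ = ann.Add, ann.Del // apply added/removed endpoints
//		}
//	}()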
package endpoints
import (
"sync"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
)
type (
// Announce consolidates changes observed in YDB endpoints since the previous observation.
Announce struct {
Add Map // contains newly discovered endpoints
Update Map // contains endpoints with changes (reserved for later use with load factor)
Del []InfoShort // contains endpoints that are no longer present in the YDB cluster
}
// Map stores endpoints as a key-value structure.
Map map[InfoShort]*Ydb_Discovery.EndpointInfo
// DB is a thread-safe in-memory storage for endpoints.
DB struct {
mx *sync.RWMutex
dbm Map
}
)
// NewDB creates endpoints DB.
func NewDB() DB {
return DB{
mx: &sync.RWMutex{},
dbm: make(Map),
}
}
// GetAll returns a copy of the internal endpoints Map.
func (db *DB) GetAll() Map {
db.mx.RLock()
defer db.mx.RUnlock()
eps := make(Map, len(db.dbm))
for k, v := range db.dbm {
eps[k] = v
}
return eps
}
// Compare takes the current state of endpoints and compares it
// with the internal endpoints Map. It returns true if the incoming
// state is identical to the internal one, false otherwise.
func (db *DB) Compare(endpoints []*Ydb_Discovery.EndpointInfo) bool {
db.mx.RLock()
defer db.mx.RUnlock()
ctr := len(db.dbm)
for _, ep := range endpoints {
if _, ok := db.dbm[NewInfoShort(ep)]; !ok {
return false
}
ctr--
}
return ctr == 0
}
// Update takes the current state of endpoints and
// - updates the internal DB accordingly,
// - constructs an endpoints announcement that reflects the performed changes.
func (db *DB) Update(endpoints []*Ydb_Discovery.EndpointInfo) (Announce, int, int) {
oldDB := db.GetAll()
newDB := make(Map, len(endpoints))
prev := len(oldDB)
length := len(endpoints)
ann := Announce{
Add: make(Map, length),
// Update: make(Map, length), // TODO
Del: make([]InfoShort, 0, length),
}
for _, ep := range endpoints {
key := NewInfoShortFromParams(
ep.Location,
ep.Address,
ep.NodeId,
ep.Port,
)
if _, ok := db.dbm[key]; !ok {
ann.Add[key] = ep
}
newDB[key] = ep
}
for k := range oldDB {
if _, ok := newDB[k]; !ok {
ann.Del = append(ann.Del, k)
}
}
db.swap(newDB)
return ann, prev, length
}
func (db *DB) swap(dbm Map) {
db.mx.Lock()
defer db.mx.Unlock()
db.dbm = dbm
}
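// Usage sketch: the DB flow as driven by the discovery service; eps stands
// for a freshly fetched []*Ydb_Discovery.EndpointInfo.
//
//	db := NewDB()
//	if !db.Compare(eps) {
//		ann, was, now := db.Update(eps)
//		_ = ann // ann.Add / ann.Del describe the delta
//		_, _ = was, now
//	}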
package endpoints
import (
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
)
const (
serviceNameQuery = "query_service"
)
type (
Filter struct {
Require *Require
Prefer *Prefer
}
Require struct {
Services []string
Locations []string
}
Prefer struct {
Locations []string
}
)
func NewFilter() *Filter {
return &Filter{
Require: &Require{},
}
}
func (f *Filter) WithQueryService() *Filter {
f.Require.Services = append(f.Require.Services, serviceNameQuery)
return f
}
func (f *Filter) matchRequired(ep *Ydb_Discovery.EndpointInfo) bool {
if f.Require == nil {
return true
}
if !matchServices(ep, f.Require.Services) {
return false
}
return matchLocation(ep, f.Require.Locations)
}
func (f *Filter) matchPreferred(ep *Ydb_Discovery.EndpointInfo) bool {
if f.Prefer == nil {
return true
}
return matchLocation(ep, f.Prefer.Locations)
}
func (f *Filter) Filter(endpoints []*Ydb_Discovery.EndpointInfo) (
preferred []*Ydb_Discovery.EndpointInfo,
notPreferred []*Ydb_Discovery.EndpointInfo,
) {
for _, ep := range endpoints {
if f.matchRequired(ep) {
if f.matchPreferred(ep) {
preferred = append(preferred, ep)
} else {
notPreferred = append(notPreferred, ep)
}
}
}
return
}
func matchServices(ep *Ydb_Discovery.EndpointInfo, services []string) bool {
srvs := make(map[string]struct{})
for _, srv := range ep.Service {
srvs[srv] = struct{}{}
}
for _, srv := range services {
if _, ok := srvs[srv]; !ok {
return false
}
}
return true
}
func matchLocation(ep *Ydb_Discovery.EndpointInfo, locations []string) bool {
if len(locations) == 0 {
return true
}
matchLoc := false
for _, loc := range locations {
if loc == ep.Location {
matchLoc = true
}
}
return matchLoc
}
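// Usage sketch: requiring the query service and preferring one location;
// eps stands for a []*Ydb_Discovery.EndpointInfo slice.
//
//	f := NewFilter().WithQueryService()
//	f.Prefer = &Prefer{Locations: []string{"dc1"}}
//	preferred, rest := f.Filter(eps)
//	_, _ = preferred, rest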
package endpoints
import (
"hash/maphash"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
)
var (
hashSeed = maphash.MakeSeed()
)
type (
// InfoShort uniquely identifies a YDB endpoint.
InfoShort struct {
Address string
Location string
AddressHash uint64
NodeID uint32
Port uint32
}
)
func NewInfoShort(ep *Ydb_Discovery.EndpointInfo) InfoShort {
return InfoShort{
NodeID: ep.NodeId,
Location: ep.Location,
Address: ep.Address,
Port: ep.Port,
AddressHash: maphash.String(hashSeed, ep.Address),
}
}
func NewInfoShortFromParams(location, address string, nodeID, port uint32) InfoShort {
return InfoShort{
NodeID: nodeID,
Location: location,
Address: address,
Port: port,
AddressHash: maphash.String(hashSeed, address),
}
}
func (eis *InfoShort) GetAddress() string {
return eis.Address
}
func (eis *InfoShort) GetPort() uint32 {
return eis.Port
}
package errors
const (
errLocalFailure = "local failure"
)
// LocalFailureError is used globally to distinguish locally originated
// request errors from I/O errors or remote-side errors.
type LocalFailureError struct {
}
func (e LocalFailureError) Error() string {
return errLocalFailure
}
package logger
import "errors"
const (
levelTrace = iota - 2
levelDebug
levelInfo
levelError
)
var (
ErrInvalidLevel = errors.New("invalid log level")
)
type (
External interface {
Error(string, []any)
Info(string, []any)
Debug(string, []any)
Trace(string, []any)
}
Logger struct {
ext External
lvl int
}
)
func parseLevel(level string) (int, error) {
switch level {
case "trace":
return levelTrace, nil
case "debug":
return levelDebug, nil
case "info":
return levelInfo, nil
case "error":
return levelError, nil
default:
return 0, ErrInvalidLevel
}
}
func New(ext External) Logger {
return Logger{ext: ext}
}
func NewWithLevel(ext External, level string) (Logger, error) {
lvl, err := parseLevel(level)
if err != nil {
return Logger{}, err
}
l := New(ext)
l.lvl = lvl
return l, nil
}
func (l *Logger) Level(lvl int) {
l.lvl = lvl
}
func (l *Logger) Trace(msg string, args ...any) {
if l.lvl > levelTrace {
return
}
l.ext.Trace(msg, args)
}
func (l *Logger) Debug(msg string, args ...any) {
if l.lvl > levelDebug {
return
}
l.ext.Debug(msg, args)
}
func (l *Logger) Info(msg string, args ...any) {
if l.lvl > levelInfo {
return
}
l.ext.Info(msg, args)
}
func (l *Logger) Error(msg string, args ...any) {
if l.lvl > levelError {
return
}
l.ext.Error(msg, args)
}
func (l *Logger) TraceFunc(f func() (string, []any)) {
if l.lvl > levelTrace {
return
}
l.ext.Trace(f())
}
func (l *Logger) DebugFunc(f func() (string, []any)) {
if l.lvl > levelDebug {
return
}
l.ext.Debug(f())
}
func (l *Logger) InfoFunc(f func() (string, []any)) {
if l.lvl > levelInfo {
return
}
l.ext.Info(f())
}
func (l *Logger) ErrorFunc(f func() (string, []any)) {
if l.lvl > levelError {
return
}
l.ext.Error(f())
}
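// Usage sketch: constructing a leveled logger around an External
// implementation, here the in-tree no-op backend.
//
//	log, err := NewWithLevel(noop.NewLogger(), "debug")
//	if err != nil {
//		// only "trace", "debug", "info" and "error" are accepted
//	}
//	log.Debug("pool created", "size", 10)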
package noop
type Logger struct {
}
func NewLogger() *Logger {
return &Logger{}
}
func (l *Logger) Error(string, []any) {
}
func (l *Logger) Info(string, []any) {
}
func (l *Logger) Debug(string, []any) {
}
func (l *Logger) Trace(string, []any) {
}
package zap
import (
"fmt"
"go.uber.org/zap"
)
type Logger struct {
*zap.Logger
}
func NewLogger(logger *zap.Logger) *Logger {
return &Logger{
Logger: logger.WithOptions(zap.AddCallerSkip(1)),
}
}
func (l *Logger) Error(msg string, fields []any) {
l.Logger.Error(msg, zapFields(fields)...)
}
func (l *Logger) Info(msg string, fields []any) {
l.Logger.Info(msg, zapFields(fields)...)
}
func (l *Logger) Debug(msg string, fields []any) {
l.Logger.Debug(msg, zapFields(fields)...)
}
func (l *Logger) Trace(msg string, fields []any) {
// zap has no trace level; trace messages are emitted at debug level
l.Logger.Debug(msg, zapFields(fields)...)
}
func zapFields(fields []any) []zap.Field {
if len(fields)%2 != 0 {
// drop the dangling key to avoid an out-of-range access below
fields = fields[:len(fields)-1]
}
zfs := make([]zap.Field, 0, len(fields)/2)
for i := 0; i < len(fields); i += 2 {
key, ok := fields[i].(string)
if !ok {
continue
}
switch val := fields[i+1].(type) {
case fmt.Stringer:
zfs = append(zfs, zap.String(key, val.String()))
case error:
zfs = append(zfs, zap.String(key, val.Error()))
default:
zfs = append(zfs, zap.Any(key, val))
}
}
return zfs
}
package zerolog
import (
"fmt"
"github.com/rs/zerolog"
)
type Logger struct {
zerolog.Logger
}
func NewLogger(logger zerolog.Logger) *Logger {
return &Logger{
Logger: logger,
}
}
func (l *Logger) Error(msg string, fields []any) {
emit(l.Logger.Error(), msg, fields)
}
func (l *Logger) Info(msg string, fields []any) {
emit(l.Logger.Info(), msg, fields)
}
func (l *Logger) Debug(msg string, fields []any) {
emit(l.Logger.Debug(), msg, fields)
}
func (l *Logger) Trace(msg string, fields []any) {
emit(l.Logger.Trace(), msg, fields)
}
func emit(ev *zerolog.Event, msg string, fields []any) {
if len(fields)%2 != 0 {
fields = fields[:len(fields)-1]
}
for i := 0; i < len(fields); i += 2 {
key, ok := fields[i].(string)
if !ok {
continue
}
switch val := fields[i+1].(type) {
case fmt.Stringer:
ev = ev.Str(key, val.String())
case error:
ev = ev.Err(val)
default:
ev = ev.Any(key, val)
}
}
ev.Msg(msg)
}
package pool
import (
"context"
"errors"
"math"
"math/rand"
"sync"
"sync/atomic"
"time"
localErrs "github.com/adwski/ydb-go-query/internal/errors"
"github.com/adwski/ydb-go-query/internal/logger"
)
const (
defaultCreateTimeout = 3 * time.Second
defaultRecycleTick = 2 * time.Second
minCreateTimeout = time.Second
minItemLifetime = 5 * time.Minute
minPoolSize = 1
defaultCreateRetryDelayOnLocalErrors = time.Second
defaultReadyThresholdHigh = 50 // percent
defaultReadyThresholdLow = 0 // percent
)
type (
item[T any] interface {
*T
ID() uint64
Alive() bool
Close() error
}
Pool[PT item[T], T any] struct {
createFunc func(context.Context, time.Duration) (PT, error)
cancelFunc context.CancelFunc
wg *sync.WaitGroup
closeOnce *sync.Once
queue chan PT
tokens chan struct{}
itemsExpire map[uint64]int64
itemsMx *sync.RWMutex
stats stats
logger logger.Logger
createTimeout time.Duration
itemLifetime int64 // seconds
recycleWindow int64 // seconds
size uint
closed atomic.Bool
itemRecycling bool
}
// Config holds pool configuration.
Config[PT item[T], T any] struct {
// CreateFunc is used to create a pool item.
// The timeout is not applied via context.WithTimeout
// because ctx is the running context for a long-lived item.
// The timeout should limit only the creation steps,
// and it is the responsibility of CreateFunc to honor it.
CreateFunc func(ctx context.Context, createTimeout time.Duration) (PT, error)
Logger logger.Logger
// CreateTimeout limits the runtime of CreateFunc.
// This timeout cannot be less than one second (minCreateTimeout).
// Default is 3 seconds (defaultCreateTimeout).
CreateTimeout time.Duration
// Lifetime specifies the item lifetime, after which the item is closed
// and a new one is created instead.
// A lifetime of 0 means the item lives forever and item recycling
// does not run.
// Lifetime cannot be less than 5 minutes (minItemLifetime).
Lifetime time.Duration
// RecycleWindow specifies the time interval for item recycling:
// [Lifetime-RecycleWindow;Lifetime+RecycleWindow].
// This prevents service degradation caused by recycling a
// significant number of items created at the same time.
RecycleWindow time.Duration
// PoolSize specifies the number of items in the pool.
PoolSize uint
// Ready thresholds specify transition points (in percent) for the ready status.
// If the number of inUse + idle sessions is greater than or equal to
// the high threshold, the pool is Ready.
// If that number is less than or equal to the low threshold, the pool is NotReady.
// Thresholds must be in the range [0;100] and satisfy lo < hi
// (they must not be equal). If these conditions are not met, the pool falls back
// to the default lo=0, hi=50 values.
ReadyThresholdPercentHigh uint
ReadyThresholdPercentLow uint
hi, lo int64
test bool
}
)
func (cfg *Config[PT, T]) validate() {
if !cfg.test { // bypass min value checks
if cfg.CreateTimeout < minCreateTimeout {
cfg.CreateTimeout = defaultCreateTimeout
}
if cfg.Lifetime < minItemLifetime {
cfg.Lifetime = 0 // infinite lifetime
}
if cfg.PoolSize < minPoolSize {
cfg.PoolSize = minPoolSize
}
}
if cfg.ReadyThresholdPercentLow > 100 {
cfg.ReadyThresholdPercentLow = defaultReadyThresholdLow
}
if cfg.ReadyThresholdPercentHigh > 100 {
cfg.ReadyThresholdPercentHigh = defaultReadyThresholdHigh
}
if cfg.ReadyThresholdPercentHigh <= cfg.ReadyThresholdPercentLow {
cfg.ReadyThresholdPercentLow = defaultReadyThresholdLow
cfg.ReadyThresholdPercentHigh = defaultReadyThresholdHigh
}
// convert from percents to actual values
cfg.hi = int64(math.Ceil(float64(cfg.ReadyThresholdPercentHigh) * float64(cfg.PoolSize) / 100))
cfg.lo = int64(math.Floor(float64(cfg.ReadyThresholdPercentLow) * float64(cfg.PoolSize) / 100))
}
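// Worked example: with PoolSize=10, ReadyThresholdPercentHigh=50 and
// ReadyThresholdPercentLow=0, validate() computes hi = ceil(50*10/100) = 5
// and lo = floor(0*10/100) = 0, so the pool becomes Ready once 5 sessions
// exist and flips back to NotReady only when the count drops to 0.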
func New[PT item[T], T any](ctx context.Context, cfg Config[PT, T]) *Pool[PT, T] {
cfg.validate()
runCtx, cancel := context.WithCancel(ctx)
pool := &Pool[PT, T]{
logger: cfg.Logger,
size: cfg.PoolSize,
createTimeout: cfg.CreateTimeout,
itemLifetime: cfg.Lifetime.Milliseconds() / 1000,
recycleWindow: cfg.RecycleWindow.Milliseconds() / 1000,
itemRecycling: cfg.Lifetime != 0,
createFunc: cfg.CreateFunc,
cancelFunc: cancel,
wg: &sync.WaitGroup{},
closeOnce: &sync.Once{},
itemsExpire: make(map[uint64]int64),
itemsMx: &sync.RWMutex{},
queue: make(chan PT, cfg.PoolSize),
tokens: make(chan struct{}, cfg.PoolSize),
stats: newStats(cfg.hi, cfg.lo),
}
// fill tokens
for i := 0; i < int(cfg.PoolSize); i++ {
pool.tokens <- struct{}{}
}
// start spawner
pool.wg.Add(1)
go pool.spawnItems(runCtx)
if pool.itemRecycling {
// start recycler
pool.wg.Add(1)
go pool.recycleItems(runCtx)
}
pool.logger.Debug("pool created", "size", pool.size)
return pool
}
func (p *Pool[PT, T]) Ready() bool {
if p.closed.Load() {
return false
}
return p.stats.ready().Get()
}
func (p *Pool[PT, T]) Close() error {
p.closeOnce.Do(func() {
p.closed.Store(true)
p.cancelFunc()
p.drain()
p.wg.Wait()
p.logger.Debug("pool closed")
})
return nil
}
func (p *Pool[PT, T]) Get(rCtx context.Context) PT {
p.stats.waiting().Inc()
defer p.stats.waiting().Dec()
getLoop:
for {
select {
case itm := <-p.queue:
p.stats.idle().Dec()
p.stats.updateReady()
if itm.Alive() {
p.stats.inUse().Inc()
p.stats.updateReady()
p.logger.Trace("item retrieved from pool", "id", itm.ID())
return itm
}
_ = itm.Close()
select {
case p.tokens <- struct{}{}:
case <-rCtx.Done():
break getLoop
}
case <-rCtx.Done():
break getLoop
}
}
return nil
}
func (p *Pool[PT, T]) Put(itm PT) {
p.stats.inUse().Dec()
defer p.stats.updateReady()
// check if alive
if itm.Alive() {
if !p.itemRecycling || !p.itemExpired(itm) {
p.stats.idle().Inc()
// alive and not expired
// push item back and finish iteration
p.queue <- itm // ignoring ctx.Done(), should never block here
p.logger.Trace("item returned to pool", "id", itm.ID())
return
}
}
p.logger.Trace("item recycled on returning", "id", itm.ID())
// recycle
_ = itm.Close()
// push token
p.tokens <- struct{}{} // ignoring ctx.Done(), should never block here
}
func (p *Pool[PT, T]) spawnItems(ctx context.Context) {
p.logger.Trace("pool spawner started")
defer func() {
p.wg.Done()
p.logger.Trace("pool spawner exited")
}()
spawnLoop:
for {
select {
case <-ctx.Done():
break spawnLoop
case <-p.tokens:
createLoop:
for {
p.wg.Add(1)
itm, err := p.spawnItem(ctx)
if err != nil {
if errors.Is(err, localErrs.LocalFailureError{}) {
// Local errors return instantly.
// Sleep for a bit to prevent an unnecessary flood of create attempts.
time.Sleep(defaultCreateRetryDelayOnLocalErrors)
}
select {
case <-ctx.Done():
break spawnLoop
default:
continue createLoop
}
}
// Ignore ctx.Done() here and put the item in the queue anyway,
// so it can be closed later by drain().
p.queue <- itm
p.stats.idle().Inc()
p.stats.updateReady()
break
}
}
}
}
func (p *Pool[PT, T]) drain() {
drainLoop:
for {
select {
case itm := <-p.queue:
p.stats.idle().Dec()
p.stats.updateReady()
_ = itm.Close()
default:
break drainLoop
}
}
}
func (p *Pool[PT, T]) spawnItem(ctx context.Context) (PT, error) {
defer p.wg.Done()
itm, err := p.createFunc(ctx, p.createTimeout)
if err != nil {
p.logger.Debug("pool item create error", "error", err)
return nil, err
}
if p.itemRecycling {
p.setItemExpire(itm.ID())
}
return itm, nil
}
func (p *Pool[PT, T]) setItemExpire(id uint64) {
p.itemsMx.Lock()
defer p.itemsMx.Unlock()
p.itemsExpire[id] = time.Now().Unix() + p.itemLifetime
}
func (p *Pool[PT, T]) getItemExpire(id uint64) int64 {
p.itemsMx.RLock()
defer p.itemsMx.RUnlock()
return p.itemsExpire[id]
}
func (p *Pool[PT, T]) itemExpired(itm PT) bool {
expire := p.getItemExpire(itm.ID())
if p.recycleWindow > 0 {
// jitter expiration within [-recycleWindow;+recycleWindow];
// rand.Int63n panics on a non-positive argument, so guard the zero window
expire += rand.Int63n(2*p.recycleWindow) - p.recycleWindow
}
return expire < time.Now().Unix()
}
func (p *Pool[PT, T]) recycleItems(ctx context.Context) {
p.logger.Trace("pool recycler started")
defer func() {
p.wg.Done()
p.logger.Trace("pool recycler exited")
}()
ticker := time.NewTicker(defaultRecycleTick)
defer ticker.Stop()
recycleLoop:
for {
// wait for tick
select {
case <-ctx.Done():
break recycleLoop
case <-ticker.C:
}
// get item from queue
select {
case <-ctx.Done():
break recycleLoop
case itm := <-p.queue:
// check if alive
if itm.Alive() && !p.itemExpired(itm) {
// alive and not expired
// push item back and finish iteration
p.queue <- itm // ignoring ctx.Done(), should never block here
break
}
// recycle
p.stats.idle().Dec()
p.stats.updateReady()
_ = itm.Close()
p.logger.Trace("item recycled", "id", itm.ID())
// push token
p.tokens <- struct{}{} // ignoring ctx.Done(), should never block here
}
}
}
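// Usage sketch: a minimal pool over a hypothetical item type Conn
// (implementing ID/Alive/Close) with a hypothetical dial helper;
// CreateFunc must honor the passed timeout itself.
//
//	p := New[*Conn, Conn](ctx, Config[*Conn, Conn]{
//		Logger:   log,
//		PoolSize: 4,
//		CreateFunc: func(ctx context.Context, timeout time.Duration) (*Conn, error) {
//			return dial(ctx, timeout)
//		},
//	})
//	if itm := p.Get(ctx); itm != nil { // nil means ctx was done first
//		defer p.Put(itm)
//	}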
package pool
import (
s "github.com/adwski/ydb-go-query/internal/stats"
)
type (
stats struct {
inUse_ *s.Gauge
idle_ *s.Gauge
waiting_ *s.Gauge
ready_ *s.Indicator
}
)
func newStats(hi, lo int64) stats {
return stats{
inUse_: s.NewGauge(),
idle_: s.NewGauge(),
waiting_: s.NewGauge(),
ready_: s.NewIndicator(hi, lo),
}
}
func (s *stats) inUse() *s.Gauge {
return s.inUse_
}
func (s *stats) idle() *s.Gauge {
return s.idle_
}
func (s *stats) ready() *s.Indicator {
return s.ready_
}
func (s *stats) waiting() *s.Gauge {
return s.waiting_
}
func (s *stats) updateReady() {
s.ready_.Observe(s.idle_.Get() + s.inUse_.Get())
}
package query
import (
"context"
"errors"
"time"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/pool"
"github.com/adwski/ydb-go-query/internal/query/session"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
"google.golang.org/grpc"
)
var (
ErrNoSession = errors.New("no session")
ErrExec = errors.New("exec failed")
)
type (
Service struct {
qsc Ydb_Query_V1.QueryServiceClient
pool *pool.Pool[*session.Session, session.Session]
logger logger.Logger
}
)
type Config struct {
Transport grpc.ClientConnInterface
Logger logger.Logger
CreateTimeout time.Duration
PoolSize uint
PoolReadyThresholdHigh uint
PoolReadyThresholdLow uint
}
func NewService(runCtx context.Context, cfg Config) *Service {
qsc := Ydb_Query_V1.NewQueryServiceClient(cfg.Transport)
sessionPool := pool.New[*session.Session, session.Session](
runCtx,
pool.Config[*session.Session, session.Session]{
Logger: cfg.Logger,
CreateTimeout: cfg.CreateTimeout,
PoolSize: cfg.PoolSize,
ReadyThresholdPercentHigh: cfg.PoolReadyThresholdHigh,
ReadyThresholdPercentLow: cfg.PoolReadyThresholdLow,
CreateFunc: func(sessCtx context.Context, timeout time.Duration) (*session.Session, error) {
return session.CreateSession(sessCtx, qsc, cfg.Logger, timeout)
},
})
svc := &Service{
logger: cfg.Logger,
qsc: qsc,
pool: sessionPool,
}
return svc
}
func (svc *Service) Close() error {
return svc.pool.Close() //nolint:wrapcheck //unnecessary
}
func (svc *Service) AcquireSession(ctx context.Context) (*session.Session, func(), error) {
sess := svc.pool.Get(ctx)
if sess == nil {
return nil, nil, ErrNoSession
}
return sess, func() { svc.pool.Put(sess) }, nil
}
// Exec provides low-level single query execution.
func (svc *Service) Exec(
ctx context.Context,
query string,
params map[string]*Ydb.TypedValue,
txSettings *Ydb_Query.TransactionSettings,
) (Ydb_Query_V1.QueryService_ExecuteQueryClient, context.CancelFunc, error) {
sess, cleanup, err := svc.AcquireSession(ctx)
if err != nil {
return nil, nil, err
}
defer cleanup()
var txControl *Ydb_Query.TransactionControl
if txSettings != nil {
txControl = &Ydb_Query.TransactionControl{
TxSelector: &Ydb_Query.TransactionControl_BeginTx{
BeginTx: txSettings,
},
CommitTx: true,
}
}
stream, cancel, err := sess.Exec(ctx, query, params, txControl)
if err != nil {
return nil, nil, errors.Join(ErrExec, err)
}
return stream, cancel, nil
}
func (svc *Service) Ready() bool {
return svc.pool.Ready()
}
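// Usage sketch: one-shot execution through the service; params may be nil,
// and a nil txSettings runs the query without an explicit transaction.
//
//	stream, cancel, err := svc.Exec(ctx, "SELECT 1;", nil, nil)
//	if err != nil {
//		return err
//	}
//	defer cancel()
//	for {
//		part, err := stream.Recv()
//		if err != nil {
//			break // io.EOF marks the end of the stream
//		}
//		_ = part.GetResultSet()
//	}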
package session
import (
"context"
"errors"
"fmt"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
)
const (
defaultStatsMode = Ydb_Query.StatsMode_STATS_MODE_BASIC
defaultQuerySyntax = Ydb_Query.Syntax_SYNTAX_YQL_V1
defaultExecMode = Ydb_Query.ExecMode_EXEC_MODE_EXECUTE
)
var (
ErrExec = errors.New("exec error")
ErrTxRollback = errors.New("transaction rollback error")
ErrTxCommit = errors.New("transaction commit error")
ErrShutdown = errors.New("session is shut down")
)
func (s *Session) RollbackTX(ctx context.Context, txID string) error {
resp, err := s.qsc.RollbackTransaction(ctx, &Ydb_Query.RollbackTransactionRequest{
SessionId: s.id,
TxId: txID,
})
if err != nil {
return errors.Join(ErrTxRollback, err)
}
if resp.Status != Ydb.StatusIds_SUCCESS {
return errors.Join(ErrTxRollback, fmt.Errorf("status: %s", resp.Status.String()))
}
return nil
}
func (s *Session) CommitTX(ctx context.Context, txID string) error {
resp, err := s.qsc.CommitTransaction(ctx, &Ydb_Query.CommitTransactionRequest{
SessionId: s.id,
TxId: txID,
})
if err != nil {
return errors.Join(ErrTxCommit, err)
}
if resp.Status != Ydb.StatusIds_SUCCESS {
return errors.Join(ErrTxCommit, fmt.Errorf("status: %s", resp.Status.String()))
}
return nil
}
func (s *Session) Exec(
ctx context.Context,
query string,
params map[string]*Ydb.TypedValue,
txControl *Ydb_Query.TransactionControl,
) (Ydb_Query_V1.QueryService_ExecuteQueryClient, context.CancelFunc, error) {
if s.shutdown.Load() {
return nil, nil, ErrShutdown
}
streamCtx, cancelStream := context.WithCancel(ctx)
respExec, err := s.qsc.ExecuteQuery(streamCtx, &Ydb_Query.ExecuteQueryRequest{
SessionId: s.id,
ExecMode: defaultExecMode,
TxControl: txControl,
Query: &Ydb_Query.ExecuteQueryRequest_QueryContent{
QueryContent: &Ydb_Query.QueryContent{
Syntax: defaultQuerySyntax,
Text: query,
},
},
Parameters: params,
StatsMode: defaultStatsMode,
ConcurrentResultSets: false,
})
if err != nil {
cancelStream()
return nil, nil, errors.Join(ErrExec, err)
}
return respExec, cancelStream, nil
}
package session
import (
"context"
"errors"
"fmt"
"hash/maphash"
"io"
"sync"
"sync/atomic"
"time"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/xcontext"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
cleanupTimeout = 3 * time.Second
)
var (
ErrSessionCreate = errors.New("session create failed")
ErrSessionTransport = errors.New("session transport was not provided")
ErrSessionAttach = errors.New("session attach failed")
ErrSessionDelete = errors.New("session delete failed")
)
var (
hashSeed maphash.Seed
)
func init() {
hashSeed = maphash.MakeSeed()
}
type (
Session struct {
logger logger.Logger
transport grpc.ClientConnInterface
qsc Ydb_Query_V1.QueryServiceClient
stream Ydb_Query_V1.QueryService_AttachSessionClient
cancelFunc context.CancelFunc
done chan struct{}
state *Ydb_Query.SessionState
err error
id string
id_ uint64
node int64
shutdown atomic.Bool
}
Config struct {
Transport grpc.ClientConnInterface
CreateResponse *Ydb_Query.CreateSessionResponse
Logger logger.Logger
}
)
func CreateSession(
ctx context.Context,
qsc Ydb_Query_V1.QueryServiceClient,
logger logger.Logger,
timeout time.Duration,
) (*Session, error) {
var transport grpc.ClientConnInterface
sessCtx := xcontext.WithTransportPtr(ctx, &transport)
createCtx, cancel := context.WithTimeout(sessCtx, timeout)
defer cancel()
respCreate, err := qsc.CreateSession(createCtx, &Ydb_Query.CreateSessionRequest{})
if err != nil {
return nil, errors.Join(ErrSessionCreate, err)
}
if respCreate.Status != Ydb.StatusIds_SUCCESS {
return nil, errors.Join(ErrSessionCreate, fmt.Errorf("status: %s", respCreate.Status))
}
if transport == nil {
return nil, ErrSessionTransport
}
sess := &Session{
logger: logger,
transport: transport,
qsc: Ydb_Query_V1.NewQueryServiceClient(transport),
id: respCreate.GetSessionId(),
id_: maphash.String(hashSeed, respCreate.GetSessionId()),
node: respCreate.GetNodeId(),
done: make(chan struct{}),
}
if err = sess.attachStream(ctx); err != nil {
go func() { _ = sess.Close() }()
return nil, err
}
return sess, nil
}
func (s *Session) ID() uint64 {
return s.id_
}
func (s *Session) Alive() bool {
return !s.shutdown.Load()
}
func (s *Session) Close() error {
s.shutdown.Store(true)
if s.cancelFunc != nil {
// cancel stream
s.cancelFunc()
}
// ensure stream is canceled
<-s.done
// cleanup session
ctx, cancel := context.WithTimeout(context.Background(), cleanupTimeout)
defer cancel()
err := errors.Join(s.err, s.cleanup(ctx))
s.logger.Debug("session closed", "id", s.id)
return err
}
func (s *Session) attachStream(ctx context.Context) error {
attachCtx, streamCancel := context.WithCancel(ctx)
respAttach, err := s.qsc.AttachSession(attachCtx, &Ydb_Query.AttachSessionRequest{
SessionId: s.id,
})
if err != nil {
streamCancel()
s.err = err
close(s.done)
return errors.Join(ErrSessionAttach, err)
}
s.stream = respAttach
s.cancelFunc = streamCancel
s.logger.Trace("attached to session", "id", s.id, "node", s.node, "id_", s.id_)
sig := make(chan struct{}) // async success signal
go s.spin(sig)
// The attach mechanism appears to be non-blocking:
// AttachSession may return while, on the YDB side,
// the session is still not attached for a short time.
// The transition to the attached state seems to be signaled by status:SUCCESS,
// so we need to wait for that status change before handing the session to the pool.
//
// Otherwise we are in a race condition, and the first query on this session may return BAD_REQUEST.
select {
case <-sig:
case <-s.done:
}
close(sig)
return nil
}
func (s *Session) spin(sigSuccess chan<- struct{}) {
once := sync.Once{}
for {
state, err := s.stream.Recv()
if err != nil {
switch {
case errors.Is(err, io.EOF):
s.logger.Debug("session stream ended", "id", s.id)
case status.Code(err) == codes.Canceled:
s.logger.Trace("session stream context canceled", "id", s.id)
default:
s.logger.Error("session stream error", "id", s.id, "err", err)
s.err = err
}
break
}
if s.state != state {
// TODO: Check state (which states can we expect here?)
s.logger.Debug("session state changed",
"id", s.id, "node", s.node, "state", state)
s.state = state
if state.Status == Ydb.StatusIds_SUCCESS {
once.Do(func() { sigSuccess <- struct{}{} })
}
}
}
s.shutdown.Store(true)
close(s.done)
}
func (s *Session) cleanup(ctx context.Context) error {
respDelete, err := s.qsc.DeleteSession(ctx, &Ydb_Query.DeleteSessionRequest{
SessionId: s.id,
})
if err != nil {
return errors.Join(ErrSessionDelete, err)
}
if respDelete.Status != Ydb.StatusIds_SUCCESS {
return errors.Join(ErrSessionDelete,
fmt.Errorf("status: %s", respDelete.Status))
}
return nil
}
package txsettings
import "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
func OnlineReadOnly() *Ydb_Query.TransactionSettings {
return &Ydb_Query.TransactionSettings{
TxMode: &Ydb_Query.TransactionSettings_OnlineReadOnly{
OnlineReadOnly: &Ydb_Query.OnlineModeSettings{},
},
}
}
func OnlineReadOnlyInconsistent() *Ydb_Query.TransactionSettings {
return &Ydb_Query.TransactionSettings{
TxMode: &Ydb_Query.TransactionSettings_OnlineReadOnly{
OnlineReadOnly: &Ydb_Query.OnlineModeSettings{
AllowInconsistentReads: true,
},
},
}
}
func SnapshotReadOnly() *Ydb_Query.TransactionSettings {
return &Ydb_Query.TransactionSettings{
TxMode: &Ydb_Query.TransactionSettings_SnapshotReadOnly{
SnapshotReadOnly: &Ydb_Query.SnapshotModeSettings{},
},
}
}
func StaleReadOnly() *Ydb_Query.TransactionSettings {
return &Ydb_Query.TransactionSettings{
TxMode: &Ydb_Query.TransactionSettings_StaleReadOnly{
StaleReadOnly: &Ydb_Query.StaleModeSettings{},
},
}
}
func SerializableReadWrite() *Ydb_Query.TransactionSettings {
return &Ydb_Query.TransactionSettings{
TxMode: &Ydb_Query.TransactionSettings_SerializableReadWrite{
SerializableReadWrite: &Ydb_Query.SerializableModeSettings{},
},
}
}
package stats
import "sync/atomic"
type Counter struct {
v *atomic.Uint64
}
func NewCounter() Counter {
return Counter{v: &atomic.Uint64{}}
}
func (c Counter) Inc() {
c.v.Add(1)
}
func (c Counter) Reset() {
c.v.Store(0)
}
func (c Counter) Get() uint64 {
return c.v.Load()
}
package stats
import "sync/atomic"
type Gauge struct {
v *atomic.Int64
}
func NewGauge() *Gauge {
return &Gauge{v: &atomic.Int64{}}
}
func (g Gauge) Inc() {
g.v.Add(1)
}
func (g Gauge) Dec() {
g.v.Add(-1)
}
func (g Gauge) Get() int64 {
return g.v.Load()
}
package stats
import (
"sync"
)
type Indicator struct {
mx sync.Mutex
v bool
thresholdHi int64
thresholdLo int64
}
func NewIndicator(hi, lo int64) *Indicator {
return &Indicator{
thresholdHi: hi,
thresholdLo: lo,
}
}
func (i *Indicator) Observe(val int64) {
i.mx.Lock()
defer i.mx.Unlock()
if i.v {
if val <= i.thresholdLo {
i.v = false
}
} else if val >= i.thresholdHi {
i.v = true
}
}
func (i *Indicator) Get() bool {
i.mx.Lock()
defer i.mx.Unlock()
return i.v
}
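// Usage sketch: hysteresis with hi=5, lo=0. Rising to 5 switches the
// indicator on, intermediate values keep the previous state, and only
// falling to 0 switches it off.
//
//	ind := NewIndicator(5, 0)
//	ind.Observe(5) // Get() == true
//	ind.Observe(2) // still true
//	ind.Observe(0) // Get() == false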
package auth
import (
"context"
"sync"
"time"
"github.com/adwski/ydb-go-query/internal/logger"
)
const (
defaultTokenCallTimeout = 5 * time.Second
defaultTokenRenewFailInterval = 10 * time.Second
defaultTokenInitialRetry = time.Second
)
type (
Provider interface {
GetToken(ctx context.Context) (string, time.Time, error)
}
Auth struct {
logger logger.Logger
provider Provider
mx *sync.RWMutex
timer *time.Timer
expires time.Time
token string
renewDisable bool
}
Config struct {
Provider Provider
Logger logger.Logger
RenewDisable bool
}
)
func New(ctx context.Context, cfg Config) (*Auth, error) {
auth := &Auth{
provider: cfg.Provider,
logger: cfg.Logger,
renewDisable: cfg.RenewDisable,
mx: &sync.RWMutex{},
}
return auth, auth.mustGetToken(ctx)
}
func (a *Auth) GetToken() string {
a.mx.RLock()
defer a.mx.RUnlock()
return a.token
}
func (a *Auth) mustGetToken(ctx context.Context) (err error) {
getTokenLoop:
for {
select {
case <-ctx.Done():
return
default:
if err = a.getTokenTick(ctx); err == nil {
break getTokenLoop
}
time.Sleep(defaultTokenInitialRetry)
}
}
return //nolint:nilerr // unnecessary
}
func (a *Auth) getTokenTick(ctx context.Context) error {
ctxCall, cancel := context.WithTimeout(ctx, defaultTokenCallTimeout)
defer cancel()
token, expires, err := a.provider.GetToken(ctxCall)
if err != nil {
a.logger.Error("token error", "error", err)
a.setTimer(defaultTokenRenewFailInterval)
return err //nolint:wrapcheck //unnecessary
}
a.mx.Lock()
a.token = token
a.expires = expires
a.mx.Unlock()
renew := a.expires.Sub(time.Now().UTC()) / 2
a.setTimer(renew)
a.logger.Info("token retrieved successfully",
"expiresAt", a.expires.Format(time.RFC3339),
"renewIn", renew.Truncate(time.Second))
return nil
}
func (a *Auth) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
if a.renewDisable {
return
}
a.logger.Debug("auth token renew started")
defer func() {
a.timer.Stop()
a.logger.Debug("auth token renew stopped")
}()
renewLoop:
for {
select {
case <-ctx.Done():
break renewLoop
case <-a.timer.C:
_ = a.getTokenTick(ctx)
}
}
}
func (a *Auth) setTimer(dur time.Duration) {
if a.timer == nil {
a.timer = time.NewTimer(dur)
} else {
a.timer.Reset(dur)
}
}
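// Usage sketch: wiring a Provider (e.g. the in-tree userpass or yc
// providers) into Auth; New blocks until the first token is fetched.
//
//	a, err := New(ctx, Config{Logger: log, Provider: provider})
//	if err != nil {
//		return err
//	}
//	wg.Add(1)
//	go a.Run(ctx, wg) // renews at half of the token's remaining lifetime
//	token := a.GetToken()
//	_ = token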
package userpass
import (
"context"
"errors"
"time"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Auth_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Auth"
"google.golang.org/grpc"
)
const (
defaultYDBAuthExpire = 12 * time.Hour
)
var (
ErrLogin = errors.New("login request failed")
ErrNilOperation = errors.New("nil operation")
ErrUnauthorized = errors.New("unauthorized")
ErrLoginUnmarshall = errors.New("login response unmarshall failed")
)
type (
UserPass struct {
authSvc Ydb_Auth_V1.AuthServiceClient
user string
pass string
}
Config struct {
Transport grpc.ClientConnInterface
Username string
Password string
}
)
func New(cfg Config) *UserPass {
auth := &UserPass{
authSvc: Ydb_Auth_V1.NewAuthServiceClient(cfg.Transport),
user: cfg.Username,
pass: cfg.Password,
}
return auth
}
func (up *UserPass) GetToken(ctx context.Context) (token string, expires time.Time, err error) {
resp, err := up.authSvc.Login(ctx, &Ydb_Auth.LoginRequest{
User: up.user,
Password: up.pass,
})
if err != nil {
err = errors.Join(ErrLogin, err)
return
}
op := resp.GetOperation()
if op == nil {
err = errors.Join(ErrLogin, ErrNilOperation)
return
}
if op.GetStatus() == Ydb.StatusIds_UNAUTHORIZED {
err = ErrUnauthorized
return
}
var result Ydb_Auth.LoginResult
if err = op.GetResult().UnmarshalTo(&result); err != nil {
err = errors.Join(ErrLoginUnmarshall, err)
return
}
token = result.Token
expires = time.Now().Add(defaultYDBAuthExpire)
return
}
package yc
import (
"context"
"errors"
"time"
ycsdk "github.com/yandex-cloud/go-sdk"
"github.com/yandex-cloud/go-sdk/iamkey"
)
var (
ErrIAMKeyUnspecified = errors.New("either IAMKey or IAMKeyFile must be provided")
ErrIAMKey = errors.New("error reading IAM key")
ErrIAMKeyFile = errors.New("error reading IAM key file")
ErrIAMTokenCreate = errors.New("error creating IAM token")
ErrServiceAccountCreds = errors.New("error reading service account credentials")
ErrYCSDK = errors.New("error creating YC SDK client")
)
type (
// YC is a Yandex Cloud authenticator.
YC struct {
sdk *ycsdk.SDK
}
// Config is the Yandex Cloud authenticator config.
Config struct {
// IamKeyFile specifies the path to an IAM key file in JSON format.
IamKeyFile string
// IamKey specifies service account IAM key (usually in json format).
// This param (if not empty) takes precedence over IamKeyFile.
IamKey []byte
}
)
func New(ctx context.Context, cfg Config) (*YC, error) {
if cfg.IamKeyFile == "" && len(cfg.IamKey) == 0 {
return nil, ErrIAMKeyUnspecified
}
var (
key *iamkey.Key
err error
)
if len(cfg.IamKey) > 0 {
if key, err = iamkey.ReadFromJSONBytes(cfg.IamKey); err != nil {
return nil, errors.Join(ErrIAMKey, err)
}
} else {
if key, err = iamkey.ReadFromJSONFile(cfg.IamKeyFile); err != nil {
return nil, errors.Join(ErrIAMKeyFile, err)
}
}
creds, err := ycsdk.ServiceAccountKey(key)
if err != nil {
return nil, errors.Join(ErrServiceAccountCreds, err)
}
yc, err := ycsdk.Build(ctx, ycsdk.Config{ // seems like Build() doesn't use context
Credentials: creds,
})
if err != nil {
return nil, errors.Join(ErrYCSDK, err)
}
return &YC{sdk: yc}, nil
}
func (a *YC) GetToken(ctx context.Context) (token string, expires time.Time, err error) {
tokenResp, err := a.sdk.CreateIAMToken(ctx)
if err != nil {
err = errors.Join(ErrIAMTokenCreate, err)
return
}
token = tokenResp.GetIamToken()
expires = tokenResp.GetExpiresAt().AsTime()
return
}
package v3
import (
"errors"
"sync"
)
const (
nodesPrealloc = 16
defaultLocation = "&&def"
)
var (
ErrConnCreate = errors.New("connection create failed")
ErrUnknownLocation = errors.New("unknown location")
ErrNoSuchID = errors.New("no such id")
)
type (
connection[T any] interface {
*T
Alive() bool
Close() error
ID() uint64
}
// Grid is a fixed-level load balancer that specifically
// balances between connections grouped by location.
Grid[PT connection[T], T any] struct {
mx *sync.Mutex
locDta map[string][]PT // connections per location
connIdxs map[string]int // next available indexes per location
blnIdxs map[string]int // round-robin indexes per location
locPrefM map[string]struct{} // locations with configured preference
locPref []string // ordered preference for locations
conns int // number of conns to spawn for each endpoint
ignoreLocations bool // use the default location for all endpoints
}
// Config provides initial params for Grid.
Config struct {
// LocationPreference defines location processing sequence.
// See GetConn().
LocationPreference []string
// ConnsPerEndpoint specifies how many individual connections will be
// spawned during an Add() call.
ConnsPerEndpoint int
// IgnoreLocations explicitly instructs the grid to ignore LocationPreference
// and use a common default location for all endpoints.
// If LocationPreference is empty, the default location is used regardless
// of this flag's value.
IgnoreLocations bool
}
createFunc[PT connection[T], T any] func() (PT, error)
)
// NewGrid creates a new grid load balancer.
func NewGrid[PT connection[T], T any](cfg Config) *Grid[PT, T] {
dta := make(map[string][]PT)
locPrefM := make(map[string]struct{})
for _, location := range cfg.LocationPreference {
dta[location] = make([]PT, 0, cfg.ConnsPerEndpoint*nodesPrealloc)
locPrefM[location] = struct{}{}
}
if len(cfg.LocationPreference) == 0 {
dta[defaultLocation] = make([]PT, 0, cfg.ConnsPerEndpoint*nodesPrealloc)
cfg.IgnoreLocations = true
}
grid := &Grid[PT, T]{
mx: &sync.Mutex{},
locDta: dta,
locPrefM: locPrefM,
connIdxs: make(map[string]int),
blnIdxs: make(map[string]int),
locPref: cfg.LocationPreference,
conns: cfg.ConnsPerEndpoint,
ignoreLocations: cfg.IgnoreLocations,
}
return grid
}
// GetConn selects a balanced connection based on available
// locations and alive connections.
// - It always returns connections from the first location
// in the LocationPreference list, when possible.
// - If there are no alive connections in the current location,
// the next location from the list is used (and so on).
// - If the end of the list is reached and there are still alive endpoints
// in locations that are not on the list, GetConn() selects from
// them in no particular location order.
// - If IgnoreLocations is set, only the default location is used.
//
// Within one location, connections are selected in round-robin fashion.
func (g *Grid[PT, T]) GetConn() PT {
g.mx.Lock()
defer g.mx.Unlock()
if g.ignoreLocations {
return g.lookupInLocation(defaultLocation)
}
// lookup in available locations according to preference
for _, loc := range g.locPref {
if _, ok := g.locDta[loc]; ok {
conn := g.lookupInLocation(loc)
if conn != nil {
return conn
}
}
}
// If some locations are not in the preference list,
// look inside them as well.
if len(g.locDta) > len(g.locPref) {
for loc := range g.locDta {
if _, ok := g.locPrefM[loc]; !ok {
conn := g.lookupInLocation(loc)
if conn != nil {
return conn
}
}
}
}
return nil
}
func (g *Grid[PT, T]) lookupInLocation(location string) PT {
var (
idx = g.blnIdxs[location]
nodes = g.locDta[location]
ln = len(nodes)
)
// Get the next alive conn,
// making a full circle in the worst case.
for i := idx; i < idx+ln; i++ {
conn := nodes[i%ln]
if conn != nil && conn.Alive() {
g.blnIdxs[location] = (i + 1) % ln
return conn
}
}
return nil
}
// Add creates connections in the specified location.
func (g *Grid[PT, T]) Add(location string, creatF createFunc[PT, T]) error {
g.mx.Lock()
defer g.mx.Unlock()
// Spawn connections
newConns := make([]PT, 0, g.conns)
for range g.conns {
conn, err := creatF()
if err != nil {
for _, conn_ := range newConns {
_ = conn_.Close()
}
return errors.Join(ErrConnCreate, err)
}
newConns = append(newConns, conn)
}
if g.ignoreLocations {
location = defaultLocation
}
nodes, ok := g.locDta[location]
if !ok {
// First time seeing this location
newLoc := make([]PT, g.conns, g.conns*nodesPrealloc)
copy(newLoc, newConns)
g.locDta[location] = newLoc
g.connIdxs[location] = g.conns
return nil
}
// get next available index
idx := g.connIdxs[location]
// add new connections to location
for _, conn := range newConns {
for {
if idx == len(nodes) {
nodes = append(nodes, conn)
g.locDta[location] = nodes
idx++
break
} else if nodes[idx] == nil {
nodes[idx] = conn
idx++
break
}
idx++
}
}
g.connIdxs[location] = idx
return nil
}
// Delete deletes connections from a location.
// It uses linear search within the location to find all matching connections.
// 'Slots' freed by deleted connections can be reused later by Add().
func (g *Grid[PT, T]) Delete(location string, id uint64) error {
g.mx.Lock()
defer g.mx.Unlock()
if g.ignoreLocations {
location = defaultLocation
}
nodes, ok := g.locDta[location]
if !ok {
return ErrUnknownLocation
}
deleted := false
// Search for connections with given ID and delete them from location.
for idx, nd := range nodes {
if nd != nil && nd.ID() == id {
if idx < g.connIdxs[location] {
// update next available index
g.connIdxs[location] = idx
}
deleted = true
_ = nd.Close()
nodes[idx] = nil
}
}
if !deleted {
return ErrNoSuchID
}
return nil
}
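// Usage sketch: filling and querying the grid; Conn is a hypothetical type
// satisfying the connection interface and dial a hypothetical factory.
//
//	g := NewGrid[*Conn, Conn](Config{
//		LocationPreference: []string{"dc1", "dc2"},
//		ConnsPerEndpoint:   2,
//	})
//	if err := g.Add("dc1", func() (*Conn, error) { return dial() }); err != nil {
//		return err
//	}
//	if conn := g.GetConn(); conn != nil {
//		// use conn; round-robin within "dc1" first
//	}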
package v4
import (
"errors"
"sync"
)
const (
defaultLocation = "&&def"
minConnectionsPerEndpoint = 1
)
var (
ErrConnCreate = errors.New("connection create failed")
ErrUnknownLocation = errors.New("unknown location")
ErrNoSuchID = errors.New("no such id")
ErrEmptyLocation = errors.New("empty location")
)
type (
connection[T any] interface {
*T
Alive() bool
Close() error
ID() uint64
}
node[PT connection[T], T any] struct {
next *node[PT, T]
conn PT
}
locationData[PT connection[T], T any] struct {
lookupPtr *node[PT, T] // points to next balancing decision
insertPtr *node[PT, T] // points to insertion point for new connections
size int // amount of connections inside current location
}
// Grid is a fixed-level load balancer that specifically
// balances between connections grouped by location.
Grid[PT connection[T], T any] struct {
mx *sync.Mutex
locDta map[string]locationData[PT, T] // connections per location
locPrefM map[string]struct{} // locations with configured preference
locPref []string // ordered preference for locations
connsPerEndpoint int // number of connections to spawn for each endpoint
ignoreLocations bool // use the default location for all endpoints
}
// Config provides initial params for Grid.
Config struct {
// LocationPreference defines location processing sequence.
// See GetConn().
LocationPreference []string
// ConnsPerEndpoint specifies how many individual connections will be
// spawned during an Add() call.
ConnsPerEndpoint int
// IgnoreLocations explicitly instructs the grid to ignore LocationPreference
// and use a common default location for all endpoints.
// If LocationPreference is empty, the default location is used regardless
// of this flag's value.
IgnoreLocations bool
}
createFunc[PT connection[T], T any] func() (PT, error)
)
// NewGrid creates a new grid load balancer.
func NewGrid[PT connection[T], T any](cfg Config) *Grid[PT, T] {
if len(cfg.LocationPreference) == 0 {
cfg.IgnoreLocations = true
}
if cfg.ConnsPerEndpoint < 1 {
cfg.ConnsPerEndpoint = minConnectionsPerEndpoint
}
grid := &Grid[PT, T]{
mx: &sync.Mutex{},
locDta: make(map[string]locationData[PT, T]),
locPrefM: make(map[string]struct{}),
locPref: cfg.LocationPreference,
connsPerEndpoint: cfg.ConnsPerEndpoint,
ignoreLocations: cfg.IgnoreLocations,
}
for _, location := range cfg.LocationPreference {
grid.locPrefM[location] = struct{}{}
}
return grid
}
// GetConn selects a balanced connection based on available
// locations and alive connections.
// - It always returns connections from the first location
// in the LocationPreference list, when possible.
// - If there are no alive connections in the first location,
// the next location from the list is used (and so on).
// - If there are no alive connections in any of the preferred locations,
// the other existing locations are checked in no particular order.
// - If IgnoreLocations is set, only the default location is used.
//
// Within one location, connections are selected in round-robin fashion.
func (g *Grid[PT, T]) GetConn() PT {
g.mx.Lock()
defer g.mx.Unlock()
if g.ignoreLocations {
return g.lookupInLocation(defaultLocation)
}
// lookup in available locations according to preference
for _, loc := range g.locPref {
if _, ok := g.locDta[loc]; ok {
conn := g.lookupInLocation(loc)
if conn != nil {
return conn
}
}
}
// If some locations are not in the preference list,
// look inside them as well.
for loc := range g.locDta {
if _, ok := g.locPrefM[loc]; !ok {
conn := g.lookupInLocation(loc)
if conn != nil {
return conn
}
}
}
return nil
}
func (g *Grid[PT, T]) lookupInLocation(location string) PT {
var (
loc = g.locDta[location]
ptr = loc.lookupPtr
size = loc.size
)
// Get the next alive conn,
// making a full circle in the worst case.
for ; size > 0; size-- {
if ptr.conn.Alive() {
loc.lookupPtr = ptr.next
g.locDta[location] = loc
return ptr.conn
}
ptr = ptr.next
}
return nil
}
// Add creates connections in the specified location.
func (g *Grid[PT, T]) Add(location string, creatF createFunc[PT, T]) error {
// Spawn connections
var (
head, prev *node[PT, T]
)
for range g.connsPerEndpoint {
conn, err := createF()
if err != nil {
for ; head != nil; head = head.next {
_ = head.conn.Close()
}
return errors.Join(ErrConnCreate, err)
}
if prev == nil {
prev = &node[PT, T]{conn: conn}
head = prev
} else {
prev.next = &node[PT, T]{conn: conn}
prev = prev.next
}
}
// Attach connections to location
g.mx.Lock()
defer g.mx.Unlock()
if g.ignoreLocations {
location = defaultLocation
}
locDta, ok := g.locDta[location]
if ok && locDta.size != 0 {
// insert conn list into location list
locDta.insertPtr.next, prev.next = head, locDta.insertPtr.next
locDta.size += g.connsPerEndpoint
g.locDta[location] = locDta
return nil
}
// First time seeing this location
prev.next = head // close the ring
g.locDta[location] = locationData[PT, T]{
lookupPtr: head,
insertPtr: prev,
size: g.connsPerEndpoint,
}
return nil
}
// Delete removes an endpoint's connections from the location.
// It uses linear search within the location to find all matching connections.
func (g *Grid[PT, T]) Delete(location string, id uint64) error {
g.mx.Lock()
defer g.mx.Unlock()
if g.ignoreLocations {
location = defaultLocation
}
locDta, ok := g.locDta[location]
switch {
case !ok:
return ErrUnknownLocation
case locDta.size == 0:
return ErrEmptyLocation
case locDta.size == g.connsPerEndpoint:
// Location holds conns of exactly one endpoint.
if locDta.insertPtr.conn.ID() == id {
// delete the last remaining endpoint
delete(g.locDta, location)
return nil
}
return ErrNoSuchID
}
var (
start *node[PT, T]
// prev and ptr start at the border between two endpoints
ptr = locDta.insertPtr.next
prev = locDta.insertPtr
)
// Conns for the same endpoint are always added as a
// contiguous range. To delete conns by endpoint id,
// find the conn (start) preceding the first conn of this endpoint,
// then find the first conn of the next endpoint,
// and finally point start.next at that conn.
for size := locDta.size; size > 0; size-- {
if ptr.conn.ID() == id {
// found first conn in deletion range
start = prev
// scroll to the first conn of the next endpoint
for ctr := 0; ctr < g.connsPerEndpoint; ctr++ {
ptr = ptr.next
}
// splice the deleted range out of the ring
start.next = ptr
// Move lookup and insert pointers
// if they are in the deleted range.
if locDta.insertPtr.conn.ID() == id {
locDta.insertPtr = ptr
}
if locDta.lookupPtr.conn.ID() == id {
locDta.lookupPtr = ptr
}
locDta.size -= g.connsPerEndpoint // keep size in sync with the remaining conns
g.locDta[location] = locDta
return nil
}
prev, ptr = ptr, ptr.next
}
return ErrNoSuchID
}
package transport
import (
"context"
"errors"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
)
const (
headerAuth = "x-ydb-auth-ticket"
headerDatabase = "x-ydb-database"
)
var (
ErrNoToken = errors.New("authenticator did not provide token")
ErrDial = errors.New("connection dial error")
ErrClose = errors.New("connection close error")
)
type (
Authenticator interface {
GetToken() string
}
Connection struct {
*grpc.ClientConn
auth Authenticator
db string
endpointID uint64
}
)
func NewConnection(
ctx context.Context,
endpoint string,
creds credentials.TransportCredentials,
auth Authenticator,
db string,
endpointID uint64,
) (*Connection, error) {
opts := []grpc.DialOption{
grpc.WithTransportCredentials(creds),
}
grpcConn, err := grpc.DialContext(ctx, endpoint, opts...)
if err != nil {
return nil, errors.Join(ErrDial, err)
}
return &Connection{
ClientConn: grpcConn,
auth: auth,
db: db,
endpointID: endpointID,
}, nil
}
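// A direct-dial sketch; the address and db path are illustrative, and
// nil auth with endpoint id 0 mirrors how the static dispatcher dials
// bootstrap nodes:
//
//	conn, err := NewConnection(ctx, "node-1.ydb.local:2136", creds, nil, "/mydb", 0)
//	if err != nil {
//		return err
//	}
//	defer func() { _ = conn.Close() }()
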
func (c *Connection) ID() uint64 {
return c.endpointID
}
func (c *Connection) Close() error {
if err := c.ClientConn.Close(); err != nil {
return errors.Join(ErrClose, err)
}
return nil
}
func (c *Connection) Alive() bool {
switch c.GetState() {
case connectivity.Ready, connectivity.Idle:
return true
default:
return false
}
}
func setContextMD(ctx context.Context, auth Authenticator, db string) (context.Context, error) {
md := metadata.Pairs(headerDatabase, db)
if auth != nil {
token := auth.GetToken()
if token == "" {
return nil, ErrNoToken
}
md.Append(headerAuth, token)
}
return metadata.NewOutgoingContext(ctx, md), nil
}
func (c *Connection) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
callCtx, err := setContextMD(ctx, c.auth, c.db)
if err != nil {
return err
}
return c.ClientConn.Invoke(callCtx, method, args, reply, opts...) //nolint:wrapcheck // unnecessary
}
func (c *Connection) NewStream(
ctx context.Context,
desc *grpc.StreamDesc,
method string,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
callCtx, err := setContextMD(ctx, c.auth, c.db)
if err != nil {
return nil, err
}
return c.ClientConn.NewStream(callCtx, desc, method, opts...) //nolint:wrapcheck // unnecessary
}
package credentials
import (
"crypto/tls"
"crypto/x509"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
func Insecure() credentials.TransportCredentials {
return insecure.NewCredentials()
}
func TLS() credentials.TransportCredentials {
return credentials.NewTLS(tlsConfig())
}
func TLSSkipVerify() credentials.TransportCredentials {
tlsCfg := tlsConfig()
tlsCfg.InsecureSkipVerify = true
return credentials.NewTLS(tlsCfg)
}
func tlsConfig() *tls.Config {
tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: x509.NewCertPool(),
}
if sysPool, err := x509.SystemCertPool(); err == nil {
tlsCfg.RootCAs = sysPool
}
return tlsCfg
}
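// Choosing credentials is a matter of matching the server's setup; a sketch:
//
//	creds := Insecure()     // plaintext grpc
//	creds = TLS()           // TLS, verified against the system cert pool
//	creds = TLSSkipVerify() // TLS for self-signed certs
//	_ = creds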
package dispatcher
import (
"context"
"strconv"
"sync"
"github.com/adwski/ydb-go-query/internal/endpoints"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/transport"
balancing "github.com/adwski/ydb-go-query/internal/transport/balancing/v4"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
type (
EndpointsProvider interface {
EndpointsChan() <-chan endpoints.Announce
}
// Dynamic is a dispatcher that provides a dynamic transport layer by
// gluing together an endpoints provider and a balancer.
//
// It uses the provider's announces to populate the balancing grid
// and acts as a grpc transport.
Dynamic struct {
logger logger.Logger
balancer *balancing.Grid[*transport.Connection, transport.Connection]
invoker *invoker
discovery EndpointsProvider
transportCredentials credentials.TransportCredentials
auth transport.Authenticator
db string
}
Config struct {
Logger logger.Logger
EndpointsProvider EndpointsProvider
Auth transport.Authenticator
TransportCredentials credentials.TransportCredentials
DB string
InitNodes []string
Balancing balancing.Config
}
)
func NewDynamic(cfg Config) *Dynamic {
grid := balancing.NewGrid[*transport.Connection, transport.Connection](cfg.Balancing)
return &Dynamic{
logger: cfg.Logger,
discovery: cfg.EndpointsProvider,
transportCredentials: cfg.TransportCredentials,
auth: cfg.Auth,
db: cfg.DB,
invoker: newInvoker(grid),
balancer: grid,
}
}
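// A wiring sketch, assuming lg, provider, creds, ctx, and wg come from the
// caller (provider is any EndpointsProvider, e.g. the discovery service):
//
//	d := NewDynamic(Config{
//		Logger:               lg,
//		EndpointsProvider:    provider,
//		TransportCredentials: creds,
//		DB:                   "/mydb",
//		Balancing:            balancing.Config{ConnsPerEndpoint: 2},
//	})
//	wg.Add(1)
//	go d.Run(ctx, wg)
//	tr := d.Transport() // grpc.ClientConnInterface backed by the grid
//	_ = tr
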
func (d *Dynamic) Run(ctx context.Context, wg *sync.WaitGroup) {
d.logger.Debug("dispatcher started")
defer func() {
d.logger.Debug("dispatcher stopped")
wg.Done()
}()
runLoop:
for {
select {
case <-ctx.Done():
break runLoop
case ann := <-d.discovery.EndpointsChan():
d.processAnnounce(ctx, ann)
}
}
}
func (d *Dynamic) Transport() grpc.ClientConnInterface {
return d.invoker
}
func (d *Dynamic) processAnnounce(ctx context.Context, ann endpoints.Announce) {
for epAdd := range ann.Add {
addr := endpointFullAddress(&epAdd)
if err := d.balancer.Add(epAdd.Location, func() (*transport.Connection, error) {
return transport.NewConnection(ctx, addr, d.transportCredentials, d.auth, d.db, epAdd.AddressHash)
}); err != nil {
d.logger.Error("unable to add endpoint", "error", err)
} else {
d.logger.Debug("endpoint added", "address", addr)
}
}
for _, epDel := range ann.Del {
addr := endpointFullAddress(&epDel)
if err := d.balancer.Delete(epDel.Location, epDel.AddressHash); err != nil {
d.logger.Error("unable to delete endpoint", "error", err)
} else {
d.logger.Debug("endpoint deleted", "address", addr)
}
}
}
type addrPort interface {
GetAddress() string
GetPort() uint32
}
func endpointFullAddress(ep addrPort) string {
return ep.GetAddress() + ":" + strconv.Itoa(int(ep.GetPort()))
}
package dispatcher
import (
"context"
"errors"
localErrs "github.com/adwski/ydb-go-query/internal/errors"
"github.com/adwski/ydb-go-query/internal/transport"
balancing "github.com/adwski/ydb-go-query/internal/transport/balancing/v4"
"github.com/adwski/ydb-go-query/internal/xcontext"
"google.golang.org/grpc"
)
type invoker struct {
*balancing.Grid[*transport.Connection, transport.Connection]
}
func newInvoker(grid *balancing.Grid[*transport.Connection, transport.Connection]) *invoker {
return &invoker{Grid: grid}
}
func (inv *invoker) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
trPtr := xcontext.GetTransportPtr(ctx)
if conn := inv.GetConn(); conn != nil {
if trPtr != nil {
*trPtr = conn
}
return conn.Invoke(ctx, method, args, reply, opts...) //nolint:wrapcheck // unnecessary
}
return errors.Join(localErrs.LocalFailureError{}, ErrNoConnections)
}
func (inv *invoker) NewStream(
ctx context.Context,
desc *grpc.StreamDesc,
method string,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
if conn := inv.GetConn(); conn != nil {
return conn.NewStream(ctx, desc, method, opts...) //nolint:wrapcheck // unnecessary
}
return nil, errors.Join(localErrs.LocalFailureError{}, ErrNoConnections)
}
package dispatcher
import (
"context"
"errors"
"github.com/adwski/ydb-go-query/internal/transport"
balancing "github.com/adwski/ydb-go-query/internal/transport/balancing/v4"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
type Static struct {
invoker *invoker
}
// NewStatic provides a transport layer with a fixed set of endpoints.
func NewStatic(
ctx context.Context,
endpoints []string,
creds credentials.TransportCredentials,
auth transport.Authenticator,
db string,
) (*Static, error) {
grid := balancing.NewGrid[*transport.Connection, transport.Connection](balancing.Config{
ConnsPerEndpoint: 1,
IgnoreLocations: true,
})
for _, addr := range endpoints {
err := grid.Add("", func() (*transport.Connection, error) {
return transport.NewConnection(ctx, addr, creds, auth, db, 0)
})
if err != nil {
return nil, errors.Join(ErrCreateEndpoint, err)
}
}
return &Static{invoker: newInvoker(grid)}, nil
}
func (s *Static) Transport() grpc.ClientConnInterface {
return s.invoker
}
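// A bootstrap sketch; one fixed node is enough for the discovery transport:
//
//	st, err := NewStatic(ctx, []string{"node-1.ydb.local:2136"}, creds, nil, "/mydb")
//	if err != nil {
//		return err
//	}
//	tr := st.Transport()
//	_ = tr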
package xcontext
import (
"context"
"google.golang.org/grpc"
)
type (
transportPtr struct{}
)
func WithTransportPtr(ctx context.Context, epPtr *grpc.ClientConnInterface) context.Context {
return context.WithValue(ctx, transportPtr{}, epPtr)
}
func GetTransportPtr(ctx context.Context) *grpc.ClientConnInterface {
trPtr, ok := ctx.Value(transportPtr{}).(*grpc.ClientConnInterface)
if ok {
return trPtr
}
return nil
}
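// A sketch of observing which connection served a call; the invoker in the
// dispatcher package fills the pointer on a successful pick:
//
//	var tr grpc.ClientConnInterface
//	callCtx := WithTransportPtr(ctx, &tr)
//	// ...perform a call with callCtx through the dispatcher...
//	// afterwards tr holds the connection that was chosen (if any).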
package query
import (
"context"
"errors"
"time"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/query"
"github.com/adwski/ydb-go-query/internal/query/txsettings"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
)
const (
maxQueryLogLength = 1000
)
var (
queryLogCut = []byte("...")
)
type Ctx struct {
qSvc *query.Service
txSet *Ydb_Query.TransactionSettings
logger logger.Logger
timeout time.Duration
}
func NewCtx(
logger logger.Logger,
qSvc *query.Service,
txSet *Ydb_Query.TransactionSettings,
timeout time.Duration,
) *Ctx {
return &Ctx{
logger: logger,
qSvc: qSvc,
txSet: txSet,
timeout: timeout,
}
}
func (qc *Ctx) OnlineReadOnly() *Ctx {
newQCtx := *qc
newQCtx.txSet = txsettings.OnlineReadOnly()
return &newQCtx
}
func (qc *Ctx) OnlineReadOnlyInconsistent() *Ctx {
newQCtx := *qc
newQCtx.txSet = txsettings.OnlineReadOnlyInconsistent()
return &newQCtx
}
func (qc *Ctx) SnapshotReadOnly() *Ctx {
newQCtx := *qc
newQCtx.txSet = txsettings.SnapshotReadOnly()
return &newQCtx
}
func (qc *Ctx) StaleReadOnly() *Ctx {
newQCtx := *qc
newQCtx.txSet = txsettings.StaleReadOnly()
return &newQCtx
}
func (qc *Ctx) SerializableReadWrite() *Ctx {
newQCtx := *qc
newQCtx.txSet = txsettings.SerializableReadWrite()
return &newQCtx
}
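// Each derivation method returns a shallow copy with different tx settings,
// leaving the original Ctx untouched; a sketch (qc is a *Ctx):
//
//	roCtx := qc.SnapshotReadOnly()
//	res, err := roCtx.Query("SELECT 1").Exec(ctx)
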
func (qc *Ctx) Query(queryContent string) *Query {
return newQuery(
queryContent,
func(
ctx context.Context,
queryContent string,
params map[string]*Ydb.TypedValue,
collectRows func([]*Ydb.Value) error,
timeout time.Duration,
) (*Result, error) {
return qc.exec(ctx, queryContent, params, collectRows, qc.txSet, timeout)
},
)
}
func (qc *Ctx) Exec(ctx context.Context, queryContent string) (*Result, error) {
return qc.exec(ctx, queryContent, nil, nil, nil, 0)
}
func (qc *Ctx) exec(
ctx context.Context,
queryContent string,
params map[string]*Ydb.TypedValue,
collectRows func([]*Ydb.Value) error,
txSet *Ydb_Query.TransactionSettings,
timeout time.Duration,
) (*Result, error) {
var (
qCancel context.CancelFunc
)
if timeout == 0 {
timeout = qc.timeout
}
if timeout > 0 {
ctx, qCancel = context.WithDeadline(ctx, time.Now().Add(timeout))
defer qCancel()
}
stream, cancel, err := qc.qSvc.Exec(ctx, queryContent, params, txSet)
if err != nil {
return nil, err //nolint:wrapcheck //unnecessary
}
qc.logger.TraceFunc(func() (string, []any) {
return "received result stream", []any{"query", strip(queryContent)}
})
return qc.processResult(stream, cancel, collectRows)
}
func (qc *Ctx) Tx(ctx context.Context) (*Transaction, error) {
sess, cleanup, err := qc.qSvc.AcquireSession(ctx)
if err != nil {
return nil, err //nolint:wrapcheck //unnecessary
}
tx := &Transaction{
logger: qc.logger,
settings: qc.txSet,
sess: sess,
cleanup: cleanup,
}
return tx, nil
}
func (qc *Ctx) processResult(
stream Ydb_Query_V1.QueryService_ExecuteQueryClient,
cancel context.CancelFunc,
collectRows func([]*Ydb.Value) error,
) (*Result, error) {
res := newResult(stream, cancel, qc.logger, collectRows)
if err := res.recv(); err != nil {
return nil, errors.Join(ErrResult, err)
}
return res, nil
}
func strip(s string) string {
if len(s) > maxQueryLogLength {
b := []byte(s[:maxQueryLogLength-len(queryLogCut)])
return string(append(b, queryLogCut...))
}
return s
}
package query
import (
"context"
"time"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
)
type (
execFunc func(
context.Context,
string,
map[string]*Ydb.TypedValue,
func([]*Ydb.Value) error,
time.Duration,
) (*Result, error)
Query struct {
collectRowsFunc func([]*Ydb.Value) error
execFunc execFunc
params map[string]*Ydb.TypedValue
content string
timeout time.Duration
}
)
func newQuery(content string, eF execFunc) *Query {
return &Query{
content: content,
execFunc: eF,
}
}
func (q *Query) Params(params map[string]*Ydb.TypedValue) *Query {
q.params = params
return q
}
func (q *Query) Param(name string, val *Ydb.TypedValue) *Query {
if q.params == nil {
q.params = make(map[string]*Ydb.TypedValue)
}
q.params[name] = val
return q
}
func (q *Query) Collect(collectRowsFunc func([]*Ydb.Value) error) *Query {
q.collectRowsFunc = collectRowsFunc
return q
}
func (q *Query) Timeout(timeout time.Duration) *Query {
q.timeout = timeout
return q
}
func (q *Query) Exec(ctx context.Context) (*Result, error) {
return q.execFunc(ctx, q.content, q.params, q.collectRowsFunc, q.timeout)
}
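// A fluent usage sketch; qc is a *Ctx, and the table, column, and $id
// parameter are illustrative YQL:
//
//	res, err := qc.Query("SELECT name FROM users WHERE id = $id").
//		Param("$id", types.Uint64(1)).
//		Timeout(5 * time.Second).
//		Exec(ctx)
//
// Collect streams rows into a callback instead of buffering them in Result.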
package query
import (
"context"
"errors"
"fmt"
"sync/atomic"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Issue"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats"
)
var (
ErrPartStatus = errors.New("result part status error")
ErrStream = errors.New("result stream error")
ErrIssues = errors.New("query result has issues")
)
type Result struct {
logger logger.Logger
cancel context.CancelFunc
stream Ydb_Query_V1.QueryService_ExecuteQueryClient
stats *Ydb_TableStats.QueryStats
err error
txID string
collectRowsFunc func([]*Ydb.Value) error
issues []*Ydb_Issue.IssueMessage
cols []*Ydb.Column
rows []*Ydb.Value
done atomic.Bool
}
func newResult(
stream Ydb_Query_V1.QueryService_ExecuteQueryClient,
cancel context.CancelFunc,
logger logger.Logger,
collectRowsFunc func([]*Ydb.Value) error,
) *Result {
return &Result{
logger: logger,
stream: stream,
cancel: cancel,
collectRowsFunc: collectRowsFunc,
}
}
// close closes the result stream;
// result data remains available.
func (r *Result) close() {
r.cancel()
r.done.Store(true)
}
func (r *Result) Err() error {
return r.err
}
func (r *Result) Issues() []*Ydb_Issue.IssueMessage { return r.issues }
func (r *Result) Cols() []*Ydb.Column {
return r.cols
}
func (r *Result) Rows() []*Ydb.Value {
return r.rows
}
func (r *Result) Stats() *Ydb_TableStats.QueryStats {
return r.stats
}
func (r *Result) TxID() string {
return r.txID
}
// recv reads all parts from the result stream until completion.
// It assumes that parts arrive sequentially,
// i.e. ConcurrentResultSets is false.
func (r *Result) recv() error {
if r.done.Load() {
return nil
}
for {
part, err := r.stream.Recv()
r.logger.Trace("received result part", "part", part, "error", err)
if err != nil {
return errors.Join(ErrStream, err)
}
r.issues = append(r.issues, part.Issues...)
if part.Status != Ydb.StatusIds_SUCCESS {
r.err = errors.Join(ErrPartStatus, fmt.Errorf("status: %s", part.Status))
break
}
if part.TxMeta != nil {
r.txID = part.TxMeta.Id
}
if len(part.Issues) > 0 {
r.err = errors.Join(ErrIssues, r.err)
}
if part.ResultSet != nil {
if r.cols == nil && len(part.ResultSet.Columns) > 0 {
r.cols = part.ResultSet.Columns
}
if len(part.ResultSet.Rows) > 0 {
if r.collectRowsFunc != nil {
err = r.collectRowsFunc(part.ResultSet.Rows)
if err != nil {
r.err = errors.Join(err, r.err)
break
}
} else {
r.rows = append(r.rows, part.ResultSet.Rows...)
}
}
}
if part.ExecStats != nil {
// stats on the last part
// TODO: find better way to detect last part
r.stats = part.ExecStats
break
}
}
r.close()
return nil
}
package query
import (
"context"
"errors"
"time"
"github.com/adwski/ydb-go-query/internal/logger"
"github.com/adwski/ydb-go-query/internal/query/session"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
)
var (
ErrTxFinished = errors.New("transaction already finished")
)
type (
Transaction struct {
logger logger.Logger
sess *session.Session
cleanup func()
settings *Ydb_Query.TransactionSettings
id string
finish bool // committed or rolled back
}
)
func (tx *Transaction) Rollback(ctx context.Context) error {
if tx.finish {
return ErrTxFinished
}
if err := tx.sess.RollbackTX(ctx, tx.id); err != nil {
return err //nolint:wrapcheck // unnecessary
}
tx.finish = true
tx.cleanup()
return nil
}
func (tx *Transaction) Commit(ctx context.Context) error {
if tx.finish {
return ErrTxFinished
}
if err := tx.sess.CommitTX(ctx, tx.id); err != nil {
return err //nolint:wrapcheck // unnecessary
}
tx.finish = true
tx.cleanup()
return nil
}
func (tx *Transaction) Query(queryContent string) *TxQuery {
return newTxQuery(
queryContent,
tx.exec,
)
}
func (tx *Transaction) exec(
ctx context.Context,
query string,
params map[string]*Ydb.TypedValue,
collectRowsFunc func([]*Ydb.Value) error,
timeout time.Duration,
commit bool,
) (*Result, error) {
if tx.finish {
return nil, ErrTxFinished
}
if commit {
defer func() {
tx.finish = true
tx.cleanup()
}()
}
txControl := &Ydb_Query.TransactionControl{
// send last exec with commit
CommitTx: commit,
}
if tx.id == "" {
// begin tx
txControl.TxSelector = &Ydb_Query.TransactionControl_BeginTx{
BeginTx: tx.settings,
}
} else {
// continue tx
txControl.TxSelector = &Ydb_Query.TransactionControl_TxId{
TxId: tx.id,
}
}
var (
qCancel context.CancelFunc
)
if timeout > 0 {
ctx, qCancel = context.WithDeadline(ctx, time.Now().Add(timeout))
}
stream, cancel, err := tx.sess.Exec(ctx, query, params, txControl)
if qCancel != nil {
// ctx was wrapped with a deadline; qCancel cancels the derived ctx
// (and the stream with it), so it supersedes cancel()
cancel = qCancel
}
if err != nil {
cancel()
return nil, err //nolint:wrapcheck //unnecessary
}
res := newResult(stream, cancel, tx.logger, collectRowsFunc)
if err = res.recv(); err != nil {
return nil, errors.Join(ErrResult, err)
}
tx.id = res.TxID()
tx.logger.Trace("received tx result", "txID", tx.id)
return res, nil
}
package query
import (
"context"
"time"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
)
type (
txExecFunc func(
context.Context,
string,
map[string]*Ydb.TypedValue,
func([]*Ydb.Value) error,
time.Duration,
bool,
) (*Result, error)
TxQuery struct {
collectRowsFunc func([]*Ydb.Value) error
txExecFunc txExecFunc
params map[string]*Ydb.TypedValue
content string
timeout time.Duration
commit bool
}
)
func newTxQuery(content string, eF txExecFunc) *TxQuery {
return &TxQuery{
content: content,
txExecFunc: eF,
}
}
func (q *TxQuery) Commit() *TxQuery {
q.commit = true
return q
}
func (q *TxQuery) Collect(collectRowsFunc func([]*Ydb.Value) error) *TxQuery {
q.collectRowsFunc = collectRowsFunc
return q
}
func (q *TxQuery) Param(name string, val *Ydb.TypedValue) *TxQuery {
if q.params == nil {
q.params = make(map[string]*Ydb.TypedValue)
}
q.params[name] = val
return q
}
func (q *TxQuery) Params(params map[string]*Ydb.TypedValue) *TxQuery {
q.params = params
return q
}
func (q *TxQuery) Timeout(timeout time.Duration) *TxQuery {
q.timeout = timeout
return q
}
func (q *TxQuery) Exec(ctx context.Context) (*Result, error) {
return q.txExecFunc(ctx, q.content, q.params, q.collectRowsFunc, q.timeout, q.commit)
}
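// A transaction sketch; calling Commit() on the last query commits together
// with that exec, saving a separate commit round-trip:
//
//	tx, err := qc.Tx(ctx)
//	if err != nil {
//		return err
//	}
//	if _, err = tx.Query("UPSERT INTO t (id) VALUES (1)").Exec(ctx); err != nil {
//		_ = tx.Rollback(ctx)
//		return err
//	}
//	_, err = tx.Query("UPSERT INTO t (id) VALUES (2)").Commit().Exec(ctx)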
package types
import "github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
func Bool(val bool) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_BOOL}},
Value: &Ydb.Value{Value: &Ydb.Value_BoolValue{BoolValue: val}},
}
}
func Int32(val int32) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_INT32}},
Value: &Ydb.Value{Value: &Ydb.Value_Int32Value{Int32Value: val}},
}
}
func Uint32(val uint32) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT32}},
Value: &Ydb.Value{Value: &Ydb.Value_Uint32Value{Uint32Value: val}},
}
}
func Int64(val int64) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_INT64}},
Value: &Ydb.Value{Value: &Ydb.Value_Int64Value{Int64Value: val}},
}
}
func Uint64(val uint64) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UINT64}},
Value: &Ydb.Value{Value: &Ydb.Value_Uint64Value{Uint64Value: val}},
}
}
func Float(val float32) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_FLOAT}},
Value: &Ydb.Value{Value: &Ydb.Value_FloatValue{FloatValue: val}},
}
}
func Double(val float64) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_DOUBLE}},
Value: &Ydb.Value{Value: &Ydb.Value_DoubleValue{DoubleValue: val}},
}
}
func UTF8(val string) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UTF8}},
Value: &Ydb.Value{Value: &Ydb.Value_TextValue{TextValue: val}},
}
}
// Text is an alias for UTF8; in Ydb protos a text value pairs with the
// UTF8 type id, not STRING (which is a bytes type).
func Text(val string) *Ydb.TypedValue {
return &Ydb.TypedValue{
Type: &Ydb.Type{Type: &Ydb.Type_TypeId{TypeId: Ydb.Type_UTF8}},
Value: &Ydb.Value{Value: &Ydb.Value_TextValue{TextValue: val}},
}
}
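// Building typed query parameters; the names are illustrative:
//
//	params := map[string]*Ydb.TypedValue{
//		"$id":   Uint64(42),
//		"$name": UTF8("alice"),
//	}
//	_ = params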