Generic Concurrency в Go

f6ce7c3c077415596e1f55c259992d0c.jpg

Привет, гоферы!

В этой статье я хочу поделиться мыслями и идеями, которые у меня накопились за время работы с дженериками в Go, и в частности о том, как шаблоны многозадачности могут стать более удобными и переиспользуемыми с помощью дженериков.

TL; DR

Дженерики и горутины (и итераторы в будущем) — это отличные инструменты, которые мы можем использовать для параллельной обработки данных в общем виде.

В этой статье мы рассмотрим возможности их совместного использования.

Вступление

Давайте сформируем контекст и рассмотрим несколько простых примеров, чтобы увидеть, какую проблему решают дженерики и как мы можем встроить их в существующую модель многозадачности в Go.

В этой статье мы будем говорить об отображении (map()) коллекций или последовательностей элементов. Таким образом,  отображение — это процесс, который приводит к формированию новой коллекции элементов, где каждый элемент является результатом вызова некоторой функции f() с соответствующим элементом из исходной коллекции.

Эра Pre-Generics

Давайте рассмотрим простую функцию отображения целочисленных чисел (которую в коде Go мы будем называть transform(), чтобы избежать путанницы со встроенным типом map):

func transform([]int, func(int) int) []int

Пример реализации

func transform(xs []int, f func(int) int) []int {
    ret := make([]int, len(xs))
    for i, x := range xs {
        ret[i] = f(x)
    }
    return ret
}

Использование такой функции будет выглядеть так:

// Output: [1, 4, 9]
transform([]int{1, 2, 3}, func(n int) int {
    return n * n
})

Предположим, что теперь мы хотим отобразить целые числа в строки. Чтобы это сделать, мы просто можем определить transform() немного иначе:

func transform([]int, func(int) string) []string

Теперь мы можем использовать это следующим образом:

// Output: ["1", "2", "3"]
transform([]int{1, 2, 3}, strconv.Itoa) 

Как насчет другой функции, которая будет возвращать признак четности числа? Просто еще одна небольшая корректировка:

func transform([]int, func(int) bool) []bool

Теперь использование функции будет выглядеть так:

// Output: [false, true, false]
transform([]int{1, 2, 3}, func(n int) bool {
    return n % 2 == 0
})

Обобщая изменения в реализации transform(), которые мы сделали выше для каждого конкретного случая использования, можно сказать, что независимо от типов, с которыми работает функция, она делает ровно то же самое снова и снова. Если бы мы хотели сгенерировать код для каждого типа используя шаблоны text/template, мы могли бы сделать это так:

func transform_{{ .A }}_{{ .B }}([]{{ .A }}, func({{ .A }}) {{ .B }}) []{{ .B }}

// transform_int_int([]int, func(int) int) []int
// transform_int_string([]int, func(int) string) []string
// transform_int_bool([]int, func(int) bool) []bool

Раньше подобные шаблоны действительно использовались для генерации «generic» кода. Проект genny — один из примеров.

Эра Generics

Благодаря дженерикам, теперь мы можем параметризовать функции и типы с помощью параметров типа и определить transform() следующим образом:

func transform[A, B any]([]A, func(A) B) []B

И реализация изменится совсем чуть-чуть!

func transform[A, B any](xs []A, f func(A) B) []B {
    ret := make([]B, len(xs))
    for i, x := range xs {
        ret[i] = f(x)
    }
    return ret
}

Теперь мы можем использовать transform() для любых типов аргументов и результатов (предполагается, что square(int) int и isEven(int) bool определены где-то выше в пакете):

transform([]int{1, 2, 3}, square)       // [1, 4, 9]
transform([]int{1, 2, 3}, strconv.Itoa) // ["1", "2", "3"]
transform([]int{1, 2, 3}, isEven)       // [false, true, false]

Параллельное отображение

Давайте перейдем к основной теме этой статьи и рассмотрим шаблоны многозадачности, которые мы можем использовать совместно с дженериками, и которые станут только лучше от такого использования.

Пакет x/sync/errgroup

Прежде чем погрузиться в код, предлагаю немного отвлечся и взглянуть на очень популярную в Go библиотеку golang.org/x/sync/errgroup. Вкратце, она позволяет запускать горутины для выполнения разных задач и ожидать их завершения или ошибки.

Предполагается, что библиотека используется следующим образом:

// Create workers group and a context which will get canceled if any of the
// tasks fails.
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
	return doSomeFun(gctx)
})
g.Go(func() error {
	return doEvenMoreFun(gctx)
})
if err := g.Wait(); err != nil {
	// handle error
}

Причина, по которой я упомянул errgroup в том, что если посмотреть на то, что делает библиотека с немного другой и более общей точки зрения, то мы увидим, что по сути, она является тем же механизмом отображения. errgroup параллельно отображает набор задач (функций) с результатом их выполнения, а так же предоставляет обобщенный способ обработки и «всплытия» ошибок c прерыванием уже выполняющихся задач (через отмену context.Context).

В этой статье мы хотим создать что-то подобное, и, как намекает частое использование слов «общий» и «обобщенный», мы собираемся сделать это в общем виде.

Наивная реализация

Возвращаясь к функции transform(). Допустим, что все вызовы f() могут выполняться параллельно, не ломая выполнение нашей программы. Тогда мы можем начать с этой наивной многозадачной реализации:

func transform[A, B any](as []A, f func(A) B) []B {
	bs := make([]B, len(as))

	var wg sync.WaitGroup
	for i := 0; i < len(as); i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			bs[i] = f(as[i])
		}(i)
	}
	wg.Wait()

	return bs
}

Мы запускаем горутину для каждого входного элемента и вызываем f(elem). Затем мы сохраняем результат по соответствующему индексу в «общем» слайсе bs. Никакого контекста, прерывания, или ошибок — такая реализация не кажется полезной в чем-либо, кроме простых вычислений.

Отмена контекста

В реальном мире большинство проблем требующих многозадачности, особенно те, которые включают работу с i/o, будут контролироваться экземпляром context.Context. Если есть контекст, то может быть его тайм-аут или отмена. Давайте посмотрим на это следующим образом:

func transform[A, B any](
	ctx context.Context,
	as []A,
	f func(context.Context, A) (B, error),
) (
	[]B,
	error,
) {
	bs := make([]B, len(as))
	es := make([]error, len(as))

	subctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var wg sync.WaitGroup
	for i := 0; i < len(as); i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			bs[i], es[i] = f(subctx, as[i])
			if es[i] != nil {
				cancel()
			}
		}(i)
	}
	wg.Wait()

	err := errors.Join(es...)
	if err != nil {
		return nil, err
	}
	return bs, nil
}

Теперь у нас есть еще один «общий» слайс es для хранения потенциальных ошибок f(). Если какой-либо вызов f() завершается ошибкой, мы отменяем контекст всего вызова transform() и ожидаем, что каждый уже выполняющийся вызов f() учитывает отмену контекста и, если она происходит, прерывает выполнение как можно скорее.

Ограничение многозадачности

На самом деле, мы не можем слишком много предполагать о f(). Пользователи transform() могут захотеть ограничить количество одновременных вызовов f(). Например,  f() может отображать url на результат http запроса. Без каких-либо ограничений мы можем перегрузить сервер или забанить самих себя.

Давайте пока не будем думать о структуре параметров и просто добавим аргумент параллелизма int в аргументы функции.

На этом этапе нам нужно перейти от использования sync.WaitGroup к семафору chan, поскольку мы хотим контролировать (максимальное) количество одновременно выполняющихся горутин, а также обрабатывать отмену контекста, используя select.

func transform[A, B any](
	ctx context.Context,
	parallelism int,
	as []A,
	f func(context.Context, A) (B, error),
) (
	[]B,
	error,
) {
	bs := make([]B, len(as))
	es := make([]error, len(as))

	// FIXME: if the given context is already cancelled, no worker will be
	// started but the transform() call will return bs, nil.
	subctx, cancel := context.WithCancel(ctx)
	defer cancel()

	sem := make(chan struct{}, parallelism)
sched:
	for i := 0; i < len(as); i++ {
		// We are checking the sub-context cancellation here, in addition to
		// the user-provided context, to handle cases where f() returns an
		// error, which leads to the termination of transform.
		if subctx.Err() != nil {
			break
		}
		select {
		case <-subctx.Done():
			break sched

		case sem <- struct{}{}:
			// Being able to send a tick into the channel means we can start a
			// new worker goroutine. This could be either due to the completion
			// of a previous goroutine or because the number of started worker
			// goroutines is less than the given parallism value.
		}
		go func(i int) {
			defer func() {
				// Signal that the element has been processed and the worker
				// goroutine has completed.
				<-sem
			}()
			bs[i], es[i] = f(subctx, as[i])
			if es[i] != nil {
				cancel()
			}
		}(i)
	}
	// Since each goroutine reads off one tick from the semaphore before exit,
	// filling the channel with artificial ticks makes us sure that all started
	// goroutines completed their execution.
	//
	// FIXME: for the high values of parallelism this loop becomes slow.
	for i := 0; i < cap(sem); i++ {
		// NOTE: we do not check the user-provided context here because we want
		// to return from this function only when all the started worker
		// goroutines have completed. This is to avoid surprising users with
		// some of the f() function calls still running in the background after
		// transform() returns.
		//
		// This implies f() should respect context cancellation and return as
		// soon as its context gets cancelled.
		sem <- struct{}{}
	}

	err := errors.Join(es...)
	if err != nil {
		return nil, err
	}
	return bs, nil
}

Для этой и последующих итераций tranform() мы могли бы оставить реализацию в том виде, в котором она сейчас есть, и оставить рассматриваемые проблемы на откуп реализации f(). Например, мы могли бы просто запустить N горутин вне зависимости от ограничений многозадачности и позволить пользователю transform() частично сериализовать их так, как он хочет. Это потребовало бы накладных расходов на запуск N горутин вместо P (где P — это ограничение «параллелизма», которое может быть гораздо меньше, чем N). Это также подразумевало бы некоторые накладные расходы на синхронизацию горутин, в зависимости от используемого механизма. Поскольку все это излишне, мы продолжаем реализацию the hard way, но во многих случаях эти усложнения являются необязательными.

Пример реализации на стороне пользователя

// Initialised x/time/rate.Limiter instance.
var lim *rate.Limiter
transform(ctx, as, func(_ context.Context, url string) (int, error) {
    if err := lim.Wait(ctx); err != nil {
        return 0, err
    }

    // Process url.

    return 42, nil
})

Переиспользование горутин

В предыдущей итерации мы запускали горутины для каждой задачи, но не более parallelism горутин одновременно. Это выявляет ещё одну интересную деталь — пользователи могут захотеть иметь собственный контекст выполнения для каждой горутины. Предположим, что у нас есть N задач с максимумом P выполняющихся одновременно (и P может быть значительно меньше N). Если каждая задача требует какой-либо подготовки ресурсов, например, большой аллокации памяти, создания сессии базы данных или, может быть, запуска однопоточной Cgo «сопрограммы», то было бы логично подготовить только P ресурсов и переиспользовать их через контекст.

Как и выше, давайте оставим структуру передачи параметров в стороне.

func transform[A, B any](
	ctx context.Context,
	prepare func(context.Context) (context.Context, context.CancelFunc),
	parallelism int,
	as []A,
	f func(context.Context, A) (B, error),
) (
	[]B,
	error,
) {
	bs := make([]B, len(as))
	es := make([]error, len(as))

	// FIXME: if the given context is already cancelled, no worker will be
	// started but the transform() call will return bs, nil.
	subctx, cancel := context.WithCancel(ctx)
	defer cancel()

	sem := make(chan struct{}, parallelism)
	wrk := make(chan int)
sched:
	for i := 0; i < len(as); i++ {
		// We are checking the sub-context cancellation here, in addition to
		// the user-provided context, to handle cases where f() returns an
		// error, which leads to the termination of transform.
		if subctx.Err() != nil {
			break
		}
		select {
		case <-subctx.Done():
			break sched

		case wrk <- i:
			// There is an idle worker goroutine that is ready to process the
			// next element.
			continue

		case sem <- struct{}{}:
			// Being able to send a tick into the channel means we can start a
			// new worker goroutine. This could be either due to the completion
			// of a previous goroutine or because the number of started worker
			// goroutines is less than the given parallism value.
		}
		go func(i int) {
			defer func() {
				// Signal that the element has been processed and the worker
				// goroutine has completed.
				<-sem
			}()

			// Capture the subctx from the dispatch loop. This prevents
			// overriding it if the given prepare() function is not nil.
			subctx := subctx
			if prepare != nil {
				var cancel context.CancelFunc
				subctx, cancel = prepare(subctx)
				defer cancel()
			}
			for {
				bs[i], es[i] = f(subctx, as[i])
				if es[i] != nil {
					cancel()
					return
				}
				var ok bool
				i, ok = <-wrk
				if !ok {
					// Work channel has been closed, which means we will not
					// get any new tasks for this worker and can return.
					break
				}
			}
		}(i)
	}
	// Since each goroutine reads off one tick from the semaphore before exit,
	// filling the channel with artificial ticks makes us sure that all started
	// goroutines completed their execution.
	//
	// FIXME: for the high values of parallelism this loop becomes slow.
	for i := 0; i < cap(sem); i++ {
		// NOTE: we do not check the user-provided context here because we want
		// to return from this function only when all the started worker
		// goroutines have completed. This is to avoid surprising users with
		// some of the f() function calls still running in the background after
		// transform() returns.
		//
		// This implies f() should respect context cancellation and return as
		// soon as its context gets cancelled.
		sem <- struct{}{}
	}

	err := errors.Join(es...)
	if err != nil {
		return nil, err
	}
	return bs, nil
}

На этом этапе мы запускаем до P горутин и распределяем задачи между ними с помощью канала wrk без буфера. Мы не используем буфер, потому что мы хотим получать обратную связь о том, есть ли в данный момент бездействующие горутины или стоит ли нам рассмотреть возможность запуска новой. Как только все задачи обработаны или любой из вызовов f() завершается ошибкой, мы сигнализируем всем горутинам о необходимости прекратить дальнейшее выполнение (с помощью close(wrk)).

Как и в предыдущем разделе, этого можно добиться на уровне f(), например, используя sync.Poolf() может взять ресурс (или создать его, если в пуле нет свободных) и вернуть его, когда он больше не нужен. Поскольку набор горутин фиксирован, вероятность того, что ресурсы будут иметь хорошую локальность CPU велика, поэтому накладные расходы могут быть минимальными.

Пример реализации на стороне пользователя

// Note that this snippet assumes `transform()` can limit its concurrency.
var pool sync.Pool
transform(ctx, 8, as, func(_ context.Context, userID string) (int, error) {
    sess := pool.Get().(*db.Session)
    if sess == nil {
        // Initialise database session.
    }
    defer pool.Put(sess)

    // Process userID.

    return 42, nil
})

Обобщение transform ()

До сих пор наш фокус был на отображении слайсов, что во многих случаях достаточно. Однако, что, если мы хотим отображать типы map или chan? Можем ли мы отображать все, по чему можно итерироваться с помощью range? И, сравнивая с циклом for, действительно ли нам всегда нужно отображать значения?

Это интересные вопросы, которые приводят нас к мысли о том, что мы можем обобщить наш подход к многозадачной итерации. Мы можем определить функцию более «низкого уровня», которая будет вести себя почти так же, но будет делать немного меньше предположений о входных и выходных данных. Тогда потребуется не так много, чтобы реализовать более специфичную функцию transform(), которая использовала бы «низкоуровневую» функцию итерации. Давайте назовем такую функцию iterate() и определим входные и выходные данные функциями, а не конкретными типами данных. Мы будем извлекать входные элементы с помощью pull() и отправлять результаты обратно пользователю с помощью push(). Таким образом, пользователь iterate() сможет контролировать способ предоставления входных элементов и обработку результатов.

Мы также должны подумать о том, какие результаты iterate() будет передавать пользователю. Поскольку мы планируем сделать отображение входных элементов опциональным,  (B, error) больше не кажется единственным правильным и очевидным вариантом. На самом деле это довольно неоднозначный вопрос, и, возможно, в большинстве случаев явное возвращение ошибки было бы предпочтительнее. Однако, семантически это не имеет большого смысла, так как результат f() просто передается в вызов push() без какой-либо обработки, что означает, что у iterate() нет никаких ожиданий или зависимости относительно результата. Другими словами, результат имеет смысл только для функции push(), которую предоставляет пользователь. Кроме того, единственный возвращаемый параметр будет лучше работать с итераторами Go, которые мы рассмотрим в конце этой статьи. Имея это в виду, давайте попробуем уменьшить количество возвращаемых параметров до одного. Так же, поскольку мы планируем передавать результаты через вызов функции, нам скорее всего нужно делать это последовательно — transform() и iterate() уже имеют всю необходимую синхронизацию внутри, поэтому мы можем избавить пользователя от необходимости дополнительной синхронизации при сборе результатов.

Еще один момент, о котором следует подумать, это обработка ошибок — выше мы не связывали ошибку с входным элементом, обработка которого её вызывала. Очевидно, что f() может обернуть ошибку за нас, но более правильно для f() было бы не иметь предположений о том, как её будут вызывать. Другими словами,  f() не должна предполагать, что её вызывают как аргумент iterate(). Если она была вызвана с одним элементом a, внутри f() нет смысла оборачивать a в ошибку, так как очевидно, что именно этот a и вызвал ошибке. Этот принцип приводит нас к другому наблюдению — любое возможное связывание входного элемента с ошибкой (или любым другим результатом) также должно происходить во время выполнения push(). По тем же причинам, функция push() должна контролировать итерацию и решать, должна ли итерация быть прервана в случае ошибки.

Такой дизайн iterate() естественным образом обеспечивает flow control — если пользователь делает что-то медленное внутри функции push(), то другие горутины в конечном итоге приостановят обработку новых элементов. Это произойдет потому, что они будут заблокированы отправляя результаты своего вызова f() в канал res, который читает функция, вызывающая push().

Поскольку код функции становится слишком большим, давайте рассмотрим его по частям.

Интерфейс

func iterate[A, B any](
	ctx context.Context,
	prepare func(context.Context) (context.Context, context.CancelFunc),
	parallelism int,
	pull func() (A, bool),
	f func(context.Context, A) B,
	push func(A, B) bool,
) (err error) {

Теперь аргументы функции и возвращаемые параметры больше не имеют входных и выходных слайсов, как это было для функции transform(). Вместо этого входные элементы извлекаются с помощью вызова функции pull(), а результаты возвращаются пользователю с помощью вызова функции push(). Обратите внимание, что функция push() возвращает параметр bool, который контролирует итерацию — как только возвращается false, ни одного вызова push() больше произойдет, а контекст всех уже выполняющихся f() будет отменен. Функция iterate() возвращает только ошибку, которая может быть non-nil только при прекращении итерации из-за отмены ctx переданного в аргументах iterate(). В противном случае не будет способа узнать, почему прекратилась итерация.

Не смотря на то, что есть всего три случая, когда итерация может быть прервана:

  • pull() вернул false, что означает, что больше нет элементов для обработки.

  • push() вернул false, что означает, что пользователю больше не нужны никакие результаты.

  • ctx переданный в аргументах iterate() был отменён.

Без усложнений пользовательского кода трудно сказать, были ли обработаны все элементы до отмены ctx.

Пример

Давайте представим, что мы хотим реализовать многозадачный forEach(), который использует iterate():

func forEach[A any](
	ctx context.Context,
	in []A,
	f func(context.Context, A) error,
) (err error) {
	var i int
	iterate(ctx, nil, 0,
		func() (_ A, ok bool) {
			if i == len(in) {
				return
			}
			i++
			return in[i-1], true
		},
		f,
		func(_ A, e error) bool {
			err = e
			return e == nil
		},
	)
	if err == nil {
		// BUG: if we returned from `iterate()` call we either processed all
		// input _or_ ctx got cancelled and the iteration got interrupted.
		//
		// Simply checking if ctx.Err() is non-nil here is racy, and may
		// provide false faulty result in case when we processed all the input
		// and _then_ context got cancelled.
		//
		// On the other hand, checking here if i == len(in) as a condition of
		// completeness is incorrect, as we might pull the last element to
		// process and _then_ got interrupted by the context cancelation.
		//
		// So if iterate() doesn't return an error, one should track each
		// element processing state in `f()` call wrapper to correctly
		// distinguish cases above.
		err = ctx.Err()
	}
	return
}

Пролог

	// Create sub-context for the dispatch loop goroutine so we can stop it
	// once the user wants to stop the iteration.
	subctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// result represents input element A and the result B caused by applying
	// the given function f() to A.
	type result struct {
		a A
		b B
	}
	// loopInfo contains the dispatch loop state.
	//
	// The dispatch goroutine below signals current goroutine about the loop
	// termination by sending loopInfo to the term channel below. The current
	// goroutine uses it to understand how many elements have been dispatched
	// for processing to decide for how many results to await.
	type loopInfo struct {
		dispatched int
		err        error
	}
	// These channels are receive-only for the current goroutine and send-only
	// in the dispatch goroutine. For the sake of readability there is no type
	// constraints added.
	var (
		res  = make(chan result)
		term = make(chan loopInfo, 1)
	)

	// This wait group is used to track completion of worker goroutines started
	// by the dispatch goroutine.
	var wg sync.WaitGroup

В предыдущих версиях transform() мы сохраняли результат по индексу, соответствующему входному элементу в слайс результатов, что выглядело как bs[i] = f(as[i]). Теперь, при функциональном вводе и выводе такой подход невозможен. Поэтому, как только у нас есть результат обработки любого элемента, нам, скорее всего, нужно немедленно отправить его пользователю с помощью push(). Вот почему мы хотим иметь две горутины для распределения входных элементов между воркерами и отправки результатов пользователю — пока мы распределяем входные элементы, мы можем получить результат обработки уже распределенных ранее.

Горутина распределения элементов

	// Start the dispatch goroutine. Its purpose is to control the worker
	// goroutines, dispatch input elements among the workers, and eventually
	// signal the current goroutine about the dispatch loop termination.
	go func() {
		// wrk is a channel of input elements. It is send-only for the dispatch
		// gorouine and receive-only for the worker goroutines.
		wrk := make(chan A)

		var loop loopInfo
		defer func() {
			// Signal the workers there are no more elements to dispatch.
			close(wrk)
			// Report the dispatch loop state to the parent goroutine.
			term <- loop
		}()

		var workersCount int
		// We use a _closed_ channel here to make the select below to always be
		// able to receive from it and start up to the given parallelism number
		// of goroutines. Once workersCount == parallelism, we set the variable
		// to nil so that the select cannot read from it after.
		//
		// This is needed to:
		// - Support the special case when parallelism is 0, so that there are
		//   no limits on the number of workers.
		// - Awoid wasting time "corking" the semaphore channel while waiting
		//   for all started goroutines to complete, especially if given a
		//   large parallelism value.
		sem := make(chan struct{})
		close(sem)

Цикл распределения элементов

		for {
			if err := subctx.Err(); err != nil {
				loop.err = err
				return
			}
			a, ok := pull()
			if !ok {
				// No more input elements.
				return
			}
			if parallelism != 0 && workersCount == parallelism {
				// Prevent starting more workers.
				sem = nil
			}
			select {
			case <-subctx.Done():
				loop.err = ctx.Err()
				return

			case wrk <- a:
				// There is an idle worker goroutine that is ready to process
				// the next element.
				loop.dispatched++
				continue

			case <-sem:
				// Being able to _receive_ a tick from the channel means we can
				// start a new worker goroutine.
				loop.dispatched++
			}

			workersCount++
			wg.Add(1)

Горутина выполнения функции

			go func(a A) {
				defer wg.Done()

				// Capture the subctx from the topmost scope. This prevents
				// overriding it if the given prepare() function is not nil.
				subctx := subctx
				if prepare != nil {
					var cancel context.CancelFunc
					subctx, cancel = prepare(subctx)
					defer cancel()
				}
				for {
					r := result{a: a}
					r.b = f(subctx, a)
					select {
					case res <- r:
					case <-subctx.Done():
						// If the context is cancelled, it means no more
						// results are expected.
						return
					}
					var ok bool
					a, ok = <-wrk
					if !ok {
						break
					}
				}
			}(a)

Сбор результатов

		}
	}()

collect:
	// Wait for the results sent by the worker goroutines.
	//
	// Note the initial -1 value for the num variable since the number of
	// elements pulled and dispatched is unknown yet. We weill be notified by
	// the dispatch gorouine once the input ends or the iteration is
	// terminated.
	for i, num := 0, -1; num == -1 || i < num; {
		select {
		case <-ctx.Done():
			// We need to explicitly handle _parent_ context cancellation here
			// because it's an external interruption for us. We ignore the
			// dispatch loop termination event and stop to receive and push
			// results unconditionally.
			if err == nil {
				err = ctx.Err()
			}
			break collect

		case res := <-res:
			if !push(res.a, res.b) {
				// The user wants to stop the iteration. Signal the dispatch
				// loop about this. Note that in this case, we ignore the term
				// channel message and not return any error.
				cancel()
				break collect
			}
			i++

		case loop := <-term:
			// Dispatch loop has now terminated, and we now know the maximum
			// number of results we need receive in this loop.
			num = loop.dispatched
			err = loop.err
		}
	}

	// NOTE: we unconditionally wait for all goroutines to complete in order to
	// return to a clean state. To avoid uninterruptable sleep here users are
	// required to respect context cancellation in the provided f().
	wg.Wait()

	return err
}

Стоит отметить, что результаты возвращаются пользователю в случайном порядке — не в том порядке, в котором они были извлечены. Это ожидаемо, поскольку мы обрабатываем их параллельно.

Субъективное замечание: совместное использование sync.WaitGroup и канала sem в данном случае является редким исключением, когда использование обоих механизмов синхронизации в одном коде оправдано. Я считаю, что в большинстве случаев, если есть канал,  sync.WaitGroup будет излишним, и наоборот.

Фух, вот и все! Это было не просто, но это то, что мы хотели сделать. Давайте посмотрим, как мы можем использовать это.

Полный листинг кода

func iterate[A, B any](
	ctx context.Context,
	prepare func(context.Context) (context.Context, context.CancelFunc),
	parallelism int,
	pull func() (A, bool),
	f func(context.Context, A) B,
	push func(A, B) bool,
) (err error) {
	// Create sub-context for the dispatch loop goroutine so we can stop it
	// once the user wants to stop the iteration.
	subctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// result represents input element A and the result B caused by applying
	// the given function f() to A.
	type result struct {
		a A
		b B
	}
	// loopInfo contains the dispatch loop state.
	//
	// The dispatch goroutine below signals current goroutine about the loop
	// termination by sending loopInfo to the term channel below. The current
	// goroutine uses it to understand how many elements have been dispatched
	// for processing to decide for how many results to await.
	type loopInfo struct {
		dispatched int
		err        error
	}
	// These channels are receive-only for the current goroutine and send-only
	// in the dispatch goroutine. For the sake of readability there is no type
	// constraints added.
	var (
		res  = make(chan result)
		term = make(chan loopInfo, 1)
	)

	// This wait group is used to track completion of worker goroutines started
	// by the dispatch goroutine.
	var wg sync.WaitGroup

	// Start the dispatch goroutine. Its purpose is to control the worker
	// goroutines, dispatch input elements among the workers, and eventually
	// signal the current goroutine about the dispatch loop termination.
	go func() {
		// wrk is a channel of input elements. It is send-only for the dispatch
		// gorouine and receive-only for the worker goroutines.
		wrk := make(chan A)

		var loop loopInfo
		defer func() {
			// Signal the workers there are no more elements to dispatch.
			close(wrk)
			// Report the dispatch loop state to the parent goroutine.
			term <- loop
		}()

		var workersCount int
		// We use a _closed_ channel here to make the select below to always be
		// able to receive from it and start up to the given parallelism number
		// of goroutines. Once workersCount == parallelism, we set the variable
		// to nil so that the select cannot read from it after.
		//
		// This is needed to:
		// - Support the special case when parallelism is 0, so that there are
		//   no limits on the number of workers.
		// - Awoid wasting time "corking" the semaphore channel while waiting
		//   for all started goroutines to complete, especially if given a
		//   large parallelism value.
		sem := make(chan struct{})
		close(sem)

		for {
			if err := subctx.Err(); err != nil {
				loop.err = err
				return
			}
			a, ok := pull()
			if !ok {
				// No more input elements.
				return
			}
			if parallelism != 0 && workersCount == parallelism {
				// Prevent starting more workers.
				sem = nil
			}
			select {
			case <-subctx.Done():
				loop.err = ctx.Err()
				return

			case wrk <- a:
				// There is an idle worker goroutine that is ready to process
				// the next element.
				loop.dispatched++
				continue

			case <-sem:
				// Being able to _receive_ a tick from the channel means we can
				// start a new worker goroutine.
				loop.dispatched++
			}

			workersCount++
			wg.Add(1)

			go func(a A) {
				defer wg.Done()

				// Capture the subctx from the topmost scope. This prevents
				// overriding it if the given prepare() function is not nil.
				subctx := subctx
				if prepare != nil {
					var cancel context.CancelFunc
					subctx, cancel = prepare(subctx)
					defer cancel()
				}
				for {
					r := result{a: a}
					r.b = f(subctx, a)
					select {
					case res <- r:
					case <-subctx.Done():
						// If the context is cancelled, it means no more
						// results are expected.
						return
					}
					var ok bool
					a, ok = <-wrk
					if !ok {
						break
					}
				}
			}(a)
		}
	}()

collect:
	// Wait for the results sent by the worker goroutines.
	//
	// Note the initial -1 value for the num variable since the number of
	// elements pulled and dispatched is unknown yet. We weill be notified by
	// the dispatch gorouine once the input ends or the iteration is
	// terminated.
	for i, num := 0, -1; num == -1 || i < num; {
		select {
		case <-ctx.Done():
			// We need to explicitly handle _parent_ context cancellation here
			// because it's an external interruption for us. We ignore the
			// dispatch loop termination event and stop to receive and push
			// results unconditionally.
			if err == nil {
				err = ctx.Err()
			}
			break collect

		case res := <-res:
			if !push(res.a, res.b) {
				// The user wants to stop the iteration. Signal the dispatch
				// loop about this. Note that in this case, we ignore the term
				// channel message and not return any error.
				cancel()
				break collect
			}
			i++

		case loop := <-term:
			// Dispatch loop has now terminated, and we now know the maximum
			// number of results we need receive in this loop.
			num = loop.dispatched
			err = loop.err
		}
	}

	// NOTE: we unconditionally wait for all goroutines to complete in order to
	// return to a clean state. To avoid uninterruptable sleep here users are
	// required to respect context cancellation in the provided f().
	wg.Wait()

	return err
}

Использование iterate () для transform ()

Чтобы посмотреть, как обобщенная функция итерации может решать проблему отображения, давайте реализуем transform() используя iterate(). Очевидно, что сейчас это выглядит намного компактнее, поскольку мы перенесли всю логику многозадачности в iterate(), а в transform() сосредоточились на сохранении результатов отображения.

func transform[A, B any](
	ctx context.Context,
	prepare func(context.Context) (context.Context, context.CancelFunc),
	parallelism int,
	as []A,
	f func(context.Context, A) (B, error),
) (
	[]B, error,
) {
	bs := make([]B, len(as))
	var (
		i    int
		err1 error
	)
	err0 := iterate(ctx, prepare, parallelism,
		func() (int, bool) {
			i++
			return i - 1, i <= len(as)
		},
		func(ctx context.Context, i int) (err error) {
			bs[i], err = f(ctx, as[i])
			return
		},
		func(i int, err error) bool {
			err1 = err
			return err == nil
		},
	)
	if err := errors.Join(err0, err1); err != nil {
		return nil, err
	}
	return bs, nil
}

Новый errgroup

Чтобы закрыть аналогию с пакетом errgroup, давайте попробуем реализовать что-то подобное с использованием iterate().

type taskFunc func(context.Context) error
func errgroup(ctx context.Context) (
	g func(taskFunc),
	wait func() error,
) {
	task := make(chan taskFunc)
	done := make(chan struct{})

	var (
		err     error
		failure error
	)
	go func() {
		defer close(done)

		// NOTE: we ignore the context preparation here as we don't need it. We
		// also don't limit amount of goroutines running at the same time -- we
		// want each task to start to be executed as soon as possible.
		err = iterate(ctx, nil, 0,
			func() (f taskFunc, ok bool) {
				f, ok = <-task
				return
			},
			func(ctx context.Context, f taskFunc) error {
				return f(ctx)
			},
			func(_ taskFunc, err error) bool {
				if err != nil {
					// Cancel the group work and stop taking new tasks.
					failure = err
					return false
				}
				return true
			},
		)
	}()

	g = func(fn taskFunc) {
		// If wait() wasn't called yet, but a previously scheduled task has
		// failed already, we should ignore the task and avoid deadlock here.
		select {
		case task <- fn:
		case <-done:
		}
	}
	wait = func() error {
		close(task)
		<-done
		return errors.Join(err, failure)
	}
	return
}

Таким образом, использование функции будет очень похоже на использование errgroup:

// Create the workers group and a context which will be canceled if any of the
// tasks fails.
g, wait := errgroup(ctx)
g(func(ctx context.Context) error {
	return doSomeFun(gctx)
})
g(func(ctx context.Context) error {
	return doEvenMoreFun(gctx)
})
if err := wait(); err != nil {
	// handle error
}

Итераторы в Go

Давайте кратко рассмотрим ближайшее будущее Go, и то, как оно может повлиять на идеи рассмотренные выше.

С недавним экспериментом range over func (начиная с Go 1.22) возможна итерация range по функциям, совместимыми с типами итератора последовательностей, определенными в пакете iter. Это новая концепция Go, которая, надеюсь, станет частью стандартной библиотеки в будущих версиях. Для получения дополнительной информации, пожалуйста, прочтите range over func proposal, а также предопределяющую статью о сопрограммах в Go от Russ Cox, на основе которой реализован пакет iter.

Сделать iterate() совместимым с iter — проще простого:

func iterate[A, B any](
	ctx context.Context,
	prepare func(context.Context) (context.Context, context.CancelFunc),
	parallelism int,
	seq iter.Seq[A],
	f func(context.Context, A) B,
) iter.Seq2[A, B] {
	return func(yield func(A, B) bool) {
		pull, stop := iter.Pull(seq)
		defer stop()
		iterate(ctx, prepare, parallelism, pull, f, yield)
	}
}

Эксперимент позволяет нам делать потрясающие вещи — итерироваться по результатам параллельно обрабатываемых элементов последовательности в обычном цикле for!

// Assuming the standard library supports iterators.
seq := slices.Sequence([]int{1, 2, 3})

// Output: [1, 4, 9]
for a, b := range iterate(ctx, nil, 0, seq, square) {
	fmt.Println(a, b)
}

Выводы

Я хотел бы, чтобы это было частью стандартной библиотеки Go.

Изначально я думал, что было бы круто оставить этот раздел с одним предложением выше. Но, вероятно, все же стоит сказать несколько слов почему. Я считаю, что подобные библиотеки общего назначения могут быть гораздо лучше восприняты и применены в проектах, если большая часть сообщества согласна с их дизайном и реализацией. Конечно, у нас могут быть разные библиотеки, решающие похожие проблемы, но на мой взгляд, в некоторых случаях большое количество библиотек может приводить к большому количеству разногласий о том, что, когда и как использовать. В некоторых случаях нет ничего плохого в том, чтобы иметь совершенно разные подходы и реализации, но в некоторых случаях это также может означать и вовсе отсутствие полноценного решения. Очень часто библиотеки появляются как конкретное решение какой-то конкретной проблемы, гораздо более конкретное, чем требуется для библиотеки общего применения. Чтобы получить общee решение, дизайн и концепцию было бы здорово обсуждать задолго до начала реализации. Так происходит в open-source foundations или, в случае с Go — в команде разработчиков языка. Наличие инструментов стандартной библиотеки для многозадачной обработки данных кажется естественным развитием Go после пакетов slices,  coro и iter.

Ссылки

© Habrahabr.ru