Search Engine Architecture: From Sequential to Distributed

Strategy Pattern Overview

MCTS requires many neural network inferences per move — hundreds to thousands of forward passes through the ResNet for each game position. The optimal strategy for managing these inferences depends on the hardware: single CPU, local GPU, multiple processes, or a distributed Ray cluster. Rather than hardcoding one approach, the system uses a strategy pattern: all engines implement the SearchEngine abstract base class, and the PVMCTS facade selects the right one based on configuration.

Four modes are available:

  • Sequential: One simulation at a time, synchronous neural network inference. Used for development, debugging, and CPU-based production serving. Supports native C++ MCTS via CppSearchStrategy.
  • Vectorize: Interleaves multiple tree traversals and batches leaf evaluations into a single GPU call. Optimizes GPU utilization for local training by amortizing inference overhead across multiple simultaneous searches.
  • Multiprocess: Distributes MCTS simulations across a process pool with a central inference server via IPC. Falls back to SequentialEngine when native C++ is enabled (see table below).
  • Ray: Distributed across a Ray cluster. The most complex mode — uses an asynchronous pipelined architecture to overlap CPU tree traversal with remote GPU inference. Deep dive below.
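The strategy pattern described above can be illustrated with a minimal sketch. Only `SearchEngine` and the engine class names come from the text; the `make_engine` factory, mode strings, and placeholder return values are hypothetical stand-ins for the PVMCTS facade's configuration-based selection:

```python
from abc import ABC, abstractmethod

class SearchEngine(ABC):
    """Interface every engine implements (method signature assumed)."""
    @abstractmethod
    def search(self, root, num_simulations: int): ...

class SequentialEngine(SearchEngine):
    def search(self, root, num_simulations: int):
        return f"sequential x{num_simulations}"   # placeholder behavior

class VectorizeEngine(SearchEngine):
    def search(self, root, num_simulations: int):
        return f"vectorize x{num_simulations}"    # placeholder behavior

_ENGINES = {"sequential": SequentialEngine, "vectorize": VectorizeEngine}

def make_engine(mode: str) -> SearchEngine:
    """Hypothetical factory standing in for the facade's engine selection."""
    try:
        return _ENGINES[mode]()
    except KeyError:
        raise ValueError(f"unknown search mode: {mode!r}") from None
```

Because every engine satisfies the same interface, the search loop never needs to know which mode is active.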

Native C++ MCTS compatibility varies by mode:

| Mode | Native C++ MCTS |
|------|-----------------|
| Sequential | Full support |
| Ray | Via SequentialEngine + CppSearchStrategy with async callbacks |
| Vectorize | Not supported |
| Multiprocess | Falls back to SequentialEngine when native C++ is enabled; otherwise uses MultiprocessEngine |

The Core Problem: CPU/GPU Pipeline Stall

In a synchronous (stop-and-wait) approach, the CPU selects leaf nodes, bundles them into a batch, sends the batch to a remote GPU actor, and blocks until the result returns. Only then does the CPU resume tree traversal. This wastes both resources: the GPU sits idle during tree traversal, and the CPU sits idle during inference. With network latency to remote GPU actors, this idle time becomes the dominant bottleneck.
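A back-of-the-envelope model makes the cost concrete. Assuming (hypothetically) 2 ms of CPU tree work and 3 ms of GPU inference per batch, stop-and-wait pays both costs serially, while an ideal pipeline pays only the slower stage:

```python
def stop_and_wait_time(t_cpu: float, t_gpu: float, n_batches: int) -> float:
    """Synchronous mode: each batch pays CPU selection plus GPU inference,
    and each device idles while the other works."""
    return (t_cpu + t_gpu) * n_batches

def pipelined_time(t_cpu: float, t_gpu: float, n_batches: int) -> float:
    """Ideal overlap: after a one-batch warm-up, the slower stage dominates."""
    return max(t_cpu, t_gpu) * n_batches + min(t_cpu, t_gpu)

# With 2 ms CPU / 3 ms GPU per batch over 100 batches:
stop_and_wait_time(2, 3, 100)  # 500 ms total
pipelined_time(2, 3, 100)      # 302 ms: GPU-bound, CPU time hidden
```

The gap widens further once remote network latency is added to the synchronous round trip.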

The Solution: Pipelined Architecture

The key insight: overlap CPU work (tree selection and expansion) with GPU work (inference). While the GPU processes batch N, the CPU selects leaf nodes for batch N+1. This requires solving the duplicate path problem: without protection, multiple concurrent selections would target the same promising leaf before the first result returns, wasting simulations.

The solution has three components: RayAsyncEngine (orchestration), BatchInferenceManager (batching), and RayInferenceClient (dispatch).

[Figure: Pipelined Async MCTS — CPU / GPU swim lanes (interactive diagram)]

RayAsyncEngine

The engine orchestrates the main search loop with three phases per iteration:

Selection phase: Traverse the tree via UCB to find leaf nodes. Apply virtual loss (pending_visits += 1 on every node in the path) to discourage future selections from following the same path. Register in-flight bookkeeping (inflight_to_root, inflight_paths, pending_by_root), encode the leaf state (native C++ fast path when available), and enqueue via PendingNodeInfo. Guards prevent redundant work: skip if a leaf is already in-flight, and pause if queue pressure exceeds max_pending (backpressure).

Dispatch phase: Send ready batches to Ray actors (delegates to BatchInferenceManager).

Drain phase: Collect results via ray.wait(num_returns=1) — returns as soon as any batch completes, not waiting for all. For each result: expand the leaf (create children from policy), backup the value to root (sign-flipping propagation), and remove virtual loss (pending_visits -= 1). Freed in-flight slots trigger dispatch_ready() to send more batches immediately.
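The "return as soon as any batch completes" semantics can be sketched with the standard library's `concurrent.futures.wait(..., return_when=FIRST_COMPLETED)` as a stand-in for `ray.wait(num_returns=1)`. This is illustrative only; the real drain phase also expands the leaf, backs up the value, and removes virtual loss:

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def drain(futures, handle_result):
    """Handle each batch result as soon as it arrives, mirroring the
    ray.wait(num_returns=1) loop (expansion/backup elided)."""
    pending = set(futures)
    while pending:
        done, pending = wait(pending, return_when=FIRST_COMPLETED)
        for fut in done:
            handle_result(fut.result())

def fake_inference(delay: float) -> float:
    time.sleep(delay)       # stand-in for a remote GPU batch
    return delay

order: list[float] = []
with ThreadPoolExecutor() as pool:
    futs = [pool.submit(fake_inference, d) for d in (0.2, 0.05)]
    drain(futs, order.append)
# The fast batch is handled first even though it was submitted second:
# order == [0.05, 0.2]
```

Processing whichever batch finishes first is what keeps freed in-flight slots available for immediate re-dispatch.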

Source: alphazero/gomoku/pvmcts/search/ray/ray_async.py:218-255

# alphazero/gomoku/pvmcts/search/ray/ray_async.py:218-255
# Virtual loss + in-flight bookkeeping + enqueue with rollback
try:
    for node in path:
        node.pending_visits += 1
    inflight_to_root[leaf] = root
    inflight_paths[leaf] = path
    pending_by_root[root] += 1

    native_payload = None
    if getattr(self.game, "use_native", False) and leaf.state.native_state is not None:
        features_np = np.array(
            self.game._native_core.write_state_features(  # noqa: SLF001
                leaf.state.native_state
            ),
            copy=False,
            dtype=np.float32,
        )
        state_tensor = torch.as_tensor(features_np, dtype=dtype, device=device)
        native_payload = features_np
    else:
        state_tensor = self._encode_state_to_tensor(
            leaf.state, dtype=dtype, device=device
        )

    is_start_node = leaf is root
    self.manager.enqueue(
        PendingNodeInfo(node=leaf, is_start_node=is_start_node),
        state_tensor,
        native_state=native_payload,
    )
except Exception:
    inflight_to_root.pop(leaf, None)
    inflight_paths.pop(leaf, None)
    if pending_by_root[root] > 0:
        pending_by_root[root] -= 1
    for node in path:
        node.pending_visits = max(0, node.pending_visits - 1)
    raise


This shows the current failure-safe behavior: the engine applies virtual loss and in-flight bookkeeping inside a guarded block, supports both Python and native C++ state encoding paths, and on exception rolls back all mutable state (maps, counters, and pending_visits) before re-raising.

BatchInferenceManager

The manager accumulates individual leaf tensors into batches and dispatches them to GPU actors. Each queued item carries explicit node metadata via PendingNodeInfo, so inference outputs can be mapped back to tree nodes safely:

Source: alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:13-18, alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:56-61

class PendingNodeInfo(NamedTuple):
    node: TreeNode
    is_start_node: bool

def enqueue(
    self,
    mapping: PendingNodeInfo,
    tensor: torch.Tensor,
    native_state: object | None = None,
) -> None:


Batch dispatch itself is controlled by queue size or queue wait timeout:

Source: alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:154-161

# alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:154-161
def _should_dispatch(self) -> bool:
    """Check whether dispatch conditions are satisfied."""
    if len(self._queue) >= self.batch_size:
        return True
    if self.max_wait_ns <= 0 or self._queue_start_ns is None:
        return False
    elapsed = time.monotonic_ns() - self._queue_start_ns
    return elapsed >= self.max_wait_ns


The dispatch condition uses OR logic: fire when the queue reaches batch_size or when the queue wait crosses the configured timeout. Internally, timeout is tracked as max_wait_ns against _queue_start_ns (time.monotonic_ns()), derived from user-facing max_wait_ms. This prevents latency spikes when only a few leaves are queued near the end of search.

The manager also enforces max_inflight_batches — when too many batches are in-flight, the selection phase pauses (backpressure: max_pending = inflight_limit × batch_size). This prevents unbounded memory growth from queued tensors.
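The backpressure bound is simple arithmetic; a minimal sketch with assumed values (the names mirror the text, but the real manager tracks this internally):

```python
# Assumed configuration values, for illustration only.
inflight_limit = 4                         # max_inflight_batches
batch_size = 32
max_pending = inflight_limit * batch_size  # at most 128 leaves awaiting results

def selection_should_pause(pending_leaves: int) -> bool:
    """Backpressure: pause the selection phase once the cap is reached."""
    return pending_leaves >= max_pending
```

With these values, selection pauses at 128 outstanding leaves and resumes as drained results free in-flight slots.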

RayInferenceClient

The client dispatches inference requests to remote Ray GPU actors:

Source: alphazero/gomoku/inference/ray_client.py:289-315

# alphazero/gomoku/inference/ray_client.py:289-315
def infer_async(
    self,
    states: torch.Tensor,
    native_payload: list[object] | None = None,
    model_slot: str | None = None,
) -> ray.ObjectRef:
    """Run asynchronous inference and return an ObjectRef immediately."""
    if isinstance(states, np.ndarray):
        batch_np = states
    elif states.dim() in (3, 4):
        batch_np = states.detach().cpu().numpy()
    else:
        raise ValueError(f"Unsupported state shape: {states.shape}")

    # Least-loaded actor selection: pick the actor with fewest pending
    # requests to avoid stalling behind a slow actor.
    min_idx = 0
    min_pending = self._pending_counts[0]
    for i in range(1, len(self.actors)):
        if self._pending_counts[i] < min_pending:
            min_pending = self._pending_counts[i]
            min_idx = i
    self._pending_counts[min_idx] += 1
    actor = self.actors[min_idx]
    ref = actor.infer.remote(batch_np, native_payload, model_slot=model_slot)
    self._ref_to_actor[ref] = min_idx
    return ref


Least-loaded actor selection picks the actor with the fewest pending requests, distributing load evenly across GPU workers. The method accepts optional native_payload and model_slot, and returns a ray.ObjectRef immediately (non-blocking) so the caller proceeds with tree traversal while Ray handles compute.

On the actor side, an AsyncIO consumer loop implements aggressive batching: it blocks on the first item, then performs non-blocking grabs until batch_size is reached or a deadline expires. A single model(inputs) forward pass runs over the concatenated batch, and the results are split back to per-request futures.
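The consumer loop can be sketched with `asyncio` primitives. This is illustrative only: the queue payload shape, the `model` callable, and the sentinel-based shutdown are assumptions, and the deadline check is omitted for brevity:

```python
import asyncio

async def actor_loop(queue: asyncio.Queue, model, batch_size: int = 4) -> None:
    """Block on the first request, greedily drain more without blocking,
    run one batched forward pass, then fan results out to futures."""
    while True:
        req = await queue.get()            # blocking wait for the first item
        if req is None:                    # shutdown sentinel (assumed)
            return
        states, futures = [req[0]], [req[1]]
        while len(states) < batch_size:    # opportunistic non-blocking grabs
            try:
                req = queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if req is None:                # keep the sentinel for the outer loop
                queue.put_nowait(None)
                break
            states.append(req[0])
            futures.append(req[1])
        for fut, out in zip(futures, model(states)):  # one forward pass
            fut.set_result(out)

async def demo() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()
    futs = [loop.create_future() for _ in range(3)]
    for i, fut in enumerate(futs):
        queue.put_nowait((i, fut))
    queue.put_nowait(None)
    await actor_loop(queue, lambda batch: [x * 10 for x in batch])
    return [fut.result() for fut in futs]

print(asyncio.run(demo()))  # -> [0, 10, 20]
```

Batching only what is already queued keeps per-request latency bounded while still amortizing the forward pass when load is high.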

Virtual Loss Mechanism

Virtual loss is the mechanism that makes asynchronous MCTS correct, not just fast.

When a leaf is selected and enqueued for inference, every node on the path from root to leaf gets pending_visits += 1. The UCB calculation includes these pending visits:

UCB = Q + c \cdot P \cdot \frac{\sqrt{\max(1, N + N_{\text{pending}})}}{1 + N_{\text{child}} + N_{\text{child\_pending}}}

The \max(1, \cdot) guard matches the code path (parent_sqrt = sqrt(max(1, parent_n))) and prevents a zero-root edge case from collapsing the exploration term.

This inflates the "visited" count for nodes along the path, making that path less attractive for the next selection. Combined with the inflight_to_root dict (which prevents selecting the exact same leaf twice), this naturally spreads concurrent selections across diverse paths in the tree.
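The effect of pending visits on selection can be sketched directly from the formula above (hypothetical helper; `c = 1.5` is an assumed constant, and the symbols map onto the function parameters):

```python
import math

def ucb_with_virtual_loss(q: float, prior: float,
                          parent_n: int, parent_pending: int,
                          child_n: int, child_pending: int,
                          c: float = 1.5) -> float:
    """UCB with pending (in-flight) visits folded into both counts."""
    parent_sqrt = math.sqrt(max(1, parent_n + parent_pending))
    return q + c * prior * parent_sqrt / (1 + child_n + child_pending)

# A child on an in-flight path scores lower than an untouched sibling,
# so the next concurrent selection is steered elsewhere:
untouched = ucb_with_virtual_loss(0.0, 0.3, parent_n=10, parent_pending=0,
                                  child_n=2, child_pending=0)
in_flight = ucb_with_virtual_loss(0.0, 0.3, parent_n=10, parent_pending=1,
                                  child_n=2, child_pending=1)
assert in_flight < untouched
```

The pending visit in the denominator outweighs the one in the numerator, which is exactly what pushes concurrent selections onto different branches.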

When inference results return, pending_visits -= 1 on all path nodes. Error handling is critical: rollback uses max(0, pending_visits - 1) to guard against negative counts. A finally block in search() cleans up all outstanding virtual loss on early exit (KeyboardInterrupt, exception), preventing stale pending_visits > 0 from corrupting future searches on the same tree.