# Search Engine Architecture: From Sequential to Distributed
## Strategy Pattern Overview
MCTS requires many neural network inferences per move — hundreds to thousands of forward passes through the ResNet for each game position. The optimal strategy for managing these inferences depends on the hardware: single CPU, local GPU, multiple processes, or a distributed Ray cluster. Rather than hardcoding one approach, the system uses a strategy pattern: all engines implement the SearchEngine abstract base class, and the PVMCTS facade selects the right one based on configuration.
Four modes are available:
- Sequential: One simulation at a time, synchronous neural network inference. Used for development, debugging, and CPU-based production serving. Supports native C++ MCTS via CppSearchStrategy.
- Vectorize: Interleaves multiple tree traversals and batches leaf evaluations into a single GPU call. Optimizes GPU utilization for local training by amortizing inference overhead across multiple simultaneous searches.
- Multiprocess: Distributes MCTS simulations across a process pool with a central inference server via IPC. Falls back to SequentialEngine when native C++ is enabled (see table below).
- Ray: Distributed across a Ray cluster. The most complex mode — uses an asynchronous pipelined architecture to overlap CPU tree traversal with remote GPU inference. Deep dive below.
Native C++ MCTS compatibility varies by mode:
| Mode | Native C++ MCTS |
|---|---|
| Sequential | Full support |
| Ray | Via SequentialEngine + CppSearchStrategy with async callbacks |
| Vectorize | Not supported |
| Multiprocess | Falls back to SequentialEngine when native C++ is enabled; otherwise uses MultiprocessEngine |
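The selection pattern can be sketched as follows. This is a minimal illustration under assumptions: the `make_engine` factory, the `_ENGINES` table, and the placeholder engine body are hypothetical stand-ins, not the project's actual code; only the SearchEngine/SequentialEngine names and the mode strings come from the text above.

```python
from abc import ABC, abstractmethod


class SearchEngine(ABC):
    """Common interface implemented by all four engines (sketch)."""

    @abstractmethod
    def search(self, root_state) -> dict:
        """Run MCTS from root_state and return a move -> visit-count policy."""


class SequentialEngine(SearchEngine):
    def search(self, root_state) -> dict:
        return {"mode": "sequential"}  # placeholder body for the sketch


# Hypothetical factory standing in for the PVMCTS facade's engine selection.
_ENGINES = {"sequential": SequentialEngine}


def make_engine(mode: str) -> SearchEngine:
    try:
        return _ENGINES[mode]()
    except KeyError:
        raise ValueError(f"unknown search mode: {mode}") from None
```

The point of the pattern is that callers depend only on the SearchEngine interface, so a config change swaps the concurrency strategy without touching search call sites.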
## The Core Problem: CPU/GPU Pipeline Stall
In a synchronous (stop-and-wait) approach, the CPU selects leaf nodes, bundles them into a batch, sends the batch to a remote GPU actor, and blocks until the result returns. Only then does the CPU resume tree traversal. This wastes both resources: the GPU sits idle during tree traversal, and the CPU sits idle during inference. With network latency to remote GPU actors, this idle time becomes the dominant bottleneck.
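A back-of-the-envelope utilization model makes the stall concrete; the millisecond figures below are illustrative, not measurements from this system:

```python
def utilization(select_ms: float, infer_ms: float, pipelined: bool) -> float:
    """Fraction of wall-clock time the GPU is busy per batch."""
    if pipelined:
        # CPU selection for the next batch overlaps GPU inference for the
        # current one, so the slower stage sets the cadence.
        return infer_ms / max(select_ms, infer_ms)
    # Stop-and-wait: the GPU idles for the entire selection phase.
    return infer_ms / (select_ms + infer_ms)


# With 4 ms of selection and 6 ms of inference per batch, stop-and-wait
# keeps the GPU 60% busy; pipelining keeps it fully busy.
stop_and_wait = utilization(4, 6, pipelined=False)
overlapped = utilization(4, 6, pipelined=True)
```

Add remote-actor network latency to `infer_ms` and the stop-and-wait fraction only gets worse, which is why the pipelined design below pays off most in the Ray mode.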
## The Solution: Pipelined Architecture
The key insight: overlap CPU work (tree selection and expansion) with GPU work (inference). While the GPU processes batch n, the CPU selects leaf nodes for batch n+1. This requires solving the duplicate path problem: without protection, multiple concurrent selections would target the same promising leaf before the first result returns, wasting simulations.
The solution has three components: RayAsyncEngine (orchestration), BatchInferenceManager (batching), and RayInferenceClient (dispatch).
### RayAsyncEngine
The engine orchestrates the main search loop with three phases per iteration:
- Selection phase: Traverse the tree via UCB to find leaf nodes. Apply virtual loss (pending_visits += 1 on every node in the path) to discourage future selections from following the same path. Register in-flight bookkeeping (inflight_to_root, inflight_paths, pending_by_root), encode the leaf state (native C++ fast path when available), and enqueue via PendingNodeInfo. Guards prevent redundant work: skip if a leaf is already in-flight, and pause if queue pressure exceeds max_pending (backpressure).
- Dispatch phase: Send ready batches to Ray actors (delegates to BatchInferenceManager).
- Drain phase: Collect results via ray.wait(num_returns=1) — returns as soon as any batch completes, not waiting for all. For each result: expand the leaf (create children from policy), backup the value to root (sign-flipping propagation), and remove virtual loss (pending_visits -= 1). Freed in-flight slots trigger dispatch_ready() to send more batches immediately.
Source: alphazero/gomoku/pvmcts/search/ray/ray_async.py:218-255
```python
# alphazero/gomoku/pvmcts/search/ray/ray_async.py:218-255
# Virtual loss + in-flight bookkeeping + enqueue with rollback
try:
    for node in path:
        node.pending_visits += 1
    inflight_to_root[leaf] = root
    inflight_paths[leaf] = path
    pending_by_root[root] += 1
    native_payload = None
    if getattr(self.game, "use_native", False) and leaf.state.native_state is not None:
        features_np = np.array(
            self.game._native_core.write_state_features(  # noqa: SLF001
                leaf.state.native_state
            ),
            copy=False,
            dtype=np.float32,
        )
        state_tensor = torch.as_tensor(features_np, dtype=dtype, device=device)
        native_payload = features_np
    else:
        state_tensor = self._encode_state_to_tensor(
            leaf.state, dtype=dtype, device=device
        )
    is_start_node = leaf is root
    self.manager.enqueue(
        PendingNodeInfo(node=leaf, is_start_node=is_start_node),
        state_tensor,
        native_state=native_payload,
    )
except Exception:
    inflight_to_root.pop(leaf, None)
    inflight_paths.pop(leaf, None)
    if pending_by_root[root] > 0:
        pending_by_root[root] -= 1
    for node in path:
        node.pending_visits = max(0, node.pending_visits - 1)
    raise
```

This shows the failure-safe behavior: the engine applies virtual loss and in-flight bookkeeping inside a guarded block, supports both Python and native C++ state encoding paths, and on exception rolls back all mutable state (maps, counters, and pending_visits) before re-raising.
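The drain loop itself is not excerpted here. A runnable sketch of the same completion-order pattern follows, using concurrent.futures.wait(return_when=FIRST_COMPLETED) as a stand-in for ray.wait(num_returns=1) so no cluster is needed; drain_all and the on_result callback are illustrative names, not the engine's API:

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def drain_all(pending, on_result):
    """Process results in completion order, mirroring ray.wait(num_returns=1)."""
    order = []
    while pending:
        # Returns as soon as ANY future completes, not waiting for all.
        done, pending = wait(pending, return_when=FIRST_COMPLETED)
        for fut in done:
            result = fut.result()
            on_result(result)  # real engine: expand leaf, backup value,
            order.append(result)  # remove virtual loss, dispatch_ready()
    return order


with ThreadPoolExecutor(max_workers=4) as pool:
    # Simulated inference calls with different latencies.
    futures = {pool.submit(lambda d: (time.sleep(d), d)[1], d)
               for d in (0.03, 0.01, 0.02)}
    results = drain_all(futures, on_result=lambda r: None)
# Faster "batches" typically come back first, so slow actors never block
# results that are already available.
```

The design consequence is the same as in the engine: a slow batch never blocks the processing of batches that are already done.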
### BatchInferenceManager
The manager accumulates individual leaf tensors into batches and dispatches them to GPU actors. Each queued item carries explicit node metadata via PendingNodeInfo, so inference outputs can be mapped back to tree nodes safely:
Source: alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:13-18, alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:56-61
```python
class PendingNodeInfo(NamedTuple):
    node: TreeNode
    is_start_node: bool


def enqueue(
    self,
    mapping: PendingNodeInfo,
    tensor: torch.Tensor,
    native_state: object | None = None,
) -> None:
```

Batch dispatch itself is controlled by queue size or by queue-wait timeout:
Source: alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:154-161
```python
# alphazero/gomoku/pvmcts/search/ray/batch_inference_manager.py:154-161
def _should_dispatch(self) -> bool:
    """Check whether dispatch conditions are satisfied."""
    if len(self._queue) >= self.batch_size:
        return True
    if self.max_wait_ns <= 0 or self._queue_start_ns is None:
        return False
    elapsed = time.monotonic_ns() - self._queue_start_ns
    return elapsed >= self.max_wait_ns
```

The dispatch condition uses OR logic: fire when the queue reaches batch_size, or when the queue wait crosses the configured timeout. Internally, the timeout is tracked as max_wait_ns against _queue_start_ns (time.monotonic_ns()), derived from the user-facing max_wait_ms. This prevents latency spikes when only a few leaves are queued near the end of a search.
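The same OR condition can be exercised in isolation; DispatchGate below is a toy reimplementation for the sketch, not the module's class:

```python
import time


class DispatchGate:
    """Fire when the queue is full OR the oldest item has waited too long."""

    def __init__(self, batch_size: int, max_wait_ms: float):
        self.batch_size = batch_size
        self.max_wait_ns = int(max_wait_ms * 1_000_000)  # user-facing ms -> ns
        self._queue = []
        self._queue_start_ns = None

    def enqueue(self, item) -> None:
        if not self._queue:
            # Wait clock starts when the first item arrives.
            self._queue_start_ns = time.monotonic_ns()
        self._queue.append(item)

    def should_dispatch(self) -> bool:
        if len(self._queue) >= self.batch_size:
            return True
        if self.max_wait_ns <= 0 or self._queue_start_ns is None:
            return False
        return time.monotonic_ns() - self._queue_start_ns >= self.max_wait_ns


gate = DispatchGate(batch_size=4, max_wait_ms=5)
gate.enqueue("leaf")
partial_before_timeout = gate.should_dispatch()  # one item, clock fresh
time.sleep(0.01)
partial_after_timeout = gate.should_dispatch()   # timeout has elapsed
```

A single queued leaf does not dispatch immediately, but it also never waits longer than the timeout, which is exactly the end-of-search latency guarantee described above.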
The manager also enforces max_inflight_batches — when too many batches are in-flight, the selection phase pauses (backpressure: max_pending = inflight_limit × batch_size). This prevents unbounded memory growth from queued tensors.
### RayInferenceClient
The client dispatches inference requests to remote Ray GPU actors:
Source: alphazero/gomoku/inference/ray_client.py:289-315
```python
# alphazero/gomoku/inference/ray_client.py:289-315
def infer_async(
    self,
    states: torch.Tensor,
    native_payload: list[object] | None = None,
    model_slot: str | None = None,
) -> ray.ObjectRef:
    """Run asynchronous inference and return an ObjectRef immediately."""
    if isinstance(states, np.ndarray):
        batch_np = states
    elif states.dim() in (3, 4):
        batch_np = states.detach().cpu().numpy()
    else:
        raise ValueError(f"Unsupported state shape: {states.shape}")
    # Least-loaded actor selection: pick the actor with fewest pending
    # requests to avoid stalling behind a slow actor.
    min_idx = 0
    min_pending = self._pending_counts[0]
    for i in range(1, len(self.actors)):
        if self._pending_counts[i] < min_pending:
            min_pending = self._pending_counts[i]
            min_idx = i
    self._pending_counts[min_idx] += 1
    actor = self.actors[min_idx]
    ref = actor.infer.remote(batch_np, native_payload, model_slot=model_slot)
    self._ref_to_actor[ref] = min_idx
    return ref
```

Least-loaded actor selection picks the actor with the fewest pending requests, distributing load evenly across GPU workers. The method accepts optional native_payload and model_slot, and returns a ray.ObjectRef immediately (non-blocking) so the caller proceeds with tree traversal while Ray handles compute.
On the actor side, an AsyncIO consumer loop implements aggressive batching: it blocks on the first item, then non-blocking grabs until batch_size is reached or a deadline expires. A single model(inputs) forward pass runs over the concatenated batch, then results are split back to per-request futures.
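That consumer loop can be sketched with asyncio primitives. collect_batch is a hypothetical standalone version that uses a deadline-bounded wait in place of pure get_nowait polling; it is not the actor's actual code:

```python
import asyncio


async def collect_batch(queue: asyncio.Queue, batch_size: int,
                        max_wait: float) -> list:
    """Block on the first item, then keep grabbing until the batch is
    full or the deadline expires."""
    batch = [await queue.get()]  # block for the first request
    deadline = asyncio.get_running_loop().time() + max_wait
    while len(batch) < batch_size:
        remaining = deadline - asyncio.get_running_loop().time()
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch


async def demo():
    q = asyncio.Queue()
    for item in range(3):
        q.put_nowait(item)
    # batch_size=8 but only 3 requests queued: the deadline closes the batch.
    return await collect_batch(q, batch_size=8, max_wait=0.01)


batch = asyncio.run(demo())
```

In the real actor, the collected batch would be concatenated, run through a single model forward pass, and split back to per-request futures.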
## Virtual Loss Mechanism
Virtual loss is the mechanism that makes asynchronous MCTS correct, not just fast.
When a leaf is selected and enqueued for inference, every node on the path from root to leaf gets pending_visits += 1. The UCB calculation includes these pending visits:

$$
\mathrm{UCB}(c) = Q(c) + c_{\mathrm{puct}} \cdot P(c) \cdot \frac{\sqrt{\max\left(1,\, N(p)\right)}}{1 + N(c) + \mathrm{pending}(c)}
$$

where $N(p)$ is the parent's visit count, and $N(c)$ and $\mathrm{pending}(c)$ are the child's real and pending visits. The $\max(1, \cdot)$ guard matches the code path (parent_sqrt = sqrt(max(1, parent_n))) and prevents a zero-visit root edge case from collapsing the exploration term.
This inflates the "visited" count for nodes along the path, making that path less attractive for the next selection. Combined with the inflight_to_root dict (which prevents selecting the exact same leaf twice), this naturally spreads concurrent selections across diverse paths in the tree.
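The effect on selection can be seen numerically, assuming a standard PUCT score in which pending visits inflate the child's denominator; the constants and values below are illustrative, not taken from the codebase:

```python
import math


def puct(q: float, prior: float, parent_n: int, child_n: int,
         pending: int, c_puct: float = 1.5) -> float:
    """PUCT score with virtual loss: pending visits inflate the child's
    effective visit count, shrinking its exploration bonus."""
    parent_sqrt = math.sqrt(max(1, parent_n))  # max(1, .) guards the fresh root
    return q + c_puct * prior * parent_sqrt / (1 + child_n + pending)


# Two children of a root with 10 real visits; A starts out more attractive.
score_a = puct(q=0.10, prior=0.6, parent_n=10, child_n=3, pending=0)
score_b = puct(q=0.10, prior=0.5, parent_n=10, child_n=3, pending=0)
# A is selected and enqueued, so its path takes one virtual loss...
score_a_pending = puct(q=0.10, prior=0.6, parent_n=10, child_n=3, pending=1)
# ...and the next concurrent selection now prefers B instead.
```

One virtual loss is enough to flip the ordering here, which is how concurrent selections spread across siblings instead of piling onto one leaf.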
When inference results return, pending_visits -= 1 on all path nodes. Error handling is critical: rollback uses max(0, pending_visits - 1) to guard against negative counts. A finally block in search() cleans up all outstanding virtual loss on early exit (KeyboardInterrupt, exception), preventing stale pending_visits > 0 from corrupting future searches on the same tree.
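The cleanup path can be sketched as follows; Node and the search function are toy stand-ins for the engine's real classes, showing only the finally-block guarantee:

```python
class Node:
    def __init__(self):
        self.pending_visits = 0


def search(path, fail=False):
    """Outstanding virtual loss is removed even when search exits early."""
    inflight_paths = {}
    try:
        for n in path:
            n.pending_visits += 1  # virtual loss applied on selection
        inflight_paths["leaf"] = path
        if fail:
            raise KeyboardInterrupt  # simulate an early exit mid-search
    finally:
        for p in inflight_paths.values():
            for n in p:
                # max(0, .) guards against double-decrement during rollback
                n.pending_visits = max(0, n.pending_visits - 1)


nodes = [Node(), Node()]
try:
    search(nodes, fail=True)
except KeyboardInterrupt:
    pass
clean = all(n.pending_visits == 0 for n in nodes)
```

Without the finally block, the interrupted search would leave pending_visits > 0 on those nodes, permanently biasing UCB away from that subtree in later searches over the same tree.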