# Source code for ccvm_simulators.post_processor.adam

from .post_processor import PostProcessor, MethodType
from .box_qp_model import BoxQPModel
import time
import torch
import tqdm


class PostProcessorAdam(PostProcessor):
    """A concrete PostProcessor that refines a BoxQP solution with Adam.

    Instance state (set in ``__init__``):
        pp_time: Wall-clock seconds spent in the most recent call to
            ``postprocess``; 0 until ``postprocess`` has been called.
        method_type: Always ``MethodType.Adam``; forwarded to ``BoxQPModel``.
    """

    def __init__(self):
        self.pp_time = 0
        self.method_type = MethodType.Adam

    def postprocess(
        self,
        c,
        q_matrix,
        v_vector,
        lower_clamp=0.0,
        upper_clamp=1.0,
        num_iter=1,
        device="cpu",
    ):
        """Post-process a solver solution with projected Adam steps.

        Each iteration evaluates the ``BoxQPModel`` loss, backpropagates it,
        takes one Adam step (lr=0.01, betas=(0.9, 0.99)) and then clamps the
        parameters back into ``[lower_clamp, upper_clamp]``.

        Args:
            c (torch.tensor): The values for each variable of the problem in
                the solution found by the solver. Must be 2-D; the first
                dimension is treated as the batch size.
            q_matrix (torch.tensor): The Q matrix describing the BoxQP problem.
            v_vector (torch.tensor): The V vector describing the BoxQP problem.
            lower_clamp (float, optional): Lower bound of the box constraints.
                Defaults to 0.0.
            upper_clamp (float, optional): Upper bound of the box constraints.
                Defaults to 1.0.
            num_iter (int, optional): The number of iterations. Defaults to 1.
            device (str, optional): Device on which the gradient seed passed
                to ``backward`` is created. Defaults to "cpu".

        Returns:
            torch.tensor: The values for each variable of the problem after
            post-processing, detached from the autograd graph.

        Raises:
            TypeError: If ``c``, ``q_matrix`` or ``v_vector`` is not a tensor.
            ValueError: If ``c`` is not 2-D (raised by the size unpacking).
        """
        start_time = time.time()

        if not torch.is_tensor(c):
            raise TypeError("parameter c must be a tensor")
        if not torch.is_tensor(q_matrix):
            raise TypeError("parameter q_matrix must be a tensor")
        if not torch.is_tensor(v_vector):
            raise TypeError("parameter v_vector must be a tensor")

        (batch_size, _) = c.size()
        model = BoxQPModel(c, self.method_type)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.99))

        # The loss is not a scalar, so backward() needs an explicit gradient
        # seed: a vector of ones with one entry per batch element. It never
        # changes, so build it once, directly on the target device.
        grad_seed = torch.ones(batch_size, device=device)

        for _ in tqdm.tqdm(range(num_iter)):
            loss = model(q_matrix, v_vector)
            loss.backward(grad_seed)
            optimizer.step()
            optimizer.zero_grad()
            # Project back into the box IN PLACE. Rebinding model.params to a
            # new Parameter would leave the optimizer pointing at the old
            # tensor, so every iteration after the first would be a no-op.
            with torch.no_grad():
                model.params.clamp_(lower_clamp, upper_clamp)

        self.pp_time = time.time() - start_time

        # Return a copy so the result does not alias the model's parameter.
        return model.params.detach().clone()