import gymnasium as gym
import os
from tqdm import tqdm
import torch



class CartPole(torch.nn.Module):
    """Two-layer MLP policy network for CartPole.

    Maps observations to one unnormalized score per action. The defaults
    (4 input features, 64 hidden units, 2 outputs) match the network this
    script loads its checkpoint into. The layers are kept inside
    ``self.model`` as a ``Sequential``, so ``state_dict`` keys remain
    ``model.0.*`` / ``model.2.*`` and existing checkpoints still load.
    """

    def __init__(self, obs_dim=4, hidden_dim=64, n_actions=2):
        """Build the network.

        Args:
            obs_dim: Size of the last dimension of the input tensor.
            hidden_dim: Width of the single hidden layer.
            n_actions: Number of output scores (one per action).
        """
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(obs_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, x):
        """Return action scores; the last dimension of ``x`` becomes ``n_actions``."""
        return self.model(x)


def run(model, episodes, render_mode="human", device=None):
    """Roll out a greedy policy in CartPole-v1 and print the summed reward.

    Despite its name, ``episodes`` is a number of environment *steps*: the
    loop calls ``env.step`` ``episodes + 1`` times, resetting the environment
    whenever an episode terminates or is truncated. The printed total is the
    reward summed over all of those steps, across episode boundaries.

    Args:
        model: torch module mapping a ``(1, obs_dim)`` float observation batch
            to action scores; the action taken is the argmax over the last
            dimension.
        episodes: Step budget; ``episodes + 1`` steps are executed.
        render_mode: Passed to ``gym.make`` (e.g. ``"human"`` or
            ``"rgb_array"``).
        device: Device for the observation tensors. Defaults to the device of
            the model's first parameter (CPU if it has none), so inputs always
            match the model instead of assuming CUDA.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    env = gym.make("CartPole-v1", render_mode=render_mode)
    total_reward = 0.0
    try:
        obs, _ = env.reset()
        with torch.no_grad():
            # The "+ 1" reproduces the original loop bound (episodes + 1 steps).
            for _ in tqdm(range(episodes + 1)):
                x = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                action = model(x).argmax(dim=-1).item()

                obs, reward, terminated, truncated, _ = env.step(action)

                if terminated or truncated:
                    obs, _ = env.reset()
                total_reward += reward
    finally:
        # Always release the environment (and any render window), even if
        # the rollout raises.
        env.close()
    print(f"total reward : {total_reward}")




# Use CUDA when it is available, otherwise fall back to CPU so the script also
# runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CartPole()
# The checkpoint path is resolved against the current working directory, not
# this file's location. map_location remaps tensors saved on a GPU onto the
# selected device so loading works on CPU-only machines.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(
    os.path.join(os.getcwd(), "99.61_99_policy_net.pth"),
    map_location=device,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
run(model=model, episodes=500)
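# --- Text copied from the model's hosting web page; commented out so this
# --- script parses. It is not part of the program.
# Downloads last month
#
# -
#
# Downloads are not tracked for this model. How to track
# Video Preview
# loading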