import gymnasium as gym
import os
from tqdm import tqdm
import torch
class CartPole(torch.nn.Module):
    """Small fully connected network producing one score per action.

    The architecture is ``Linear -> ReLU -> Linear``. With the default
    sizes it maps a 4-element input to 2 output scores.

    Args:
        obs_dim: Number of input features (default 4).
        hidden_dim: Width of the hidden layer (default 64).
        n_actions: Number of output scores (default 2).
    """

    def __init__(
        self,
        obs_dim: int = 4,
        hidden_dim: int = 64,
        n_actions: int = 2,
    ) -> None:
        super().__init__()
        # The attribute name ``model`` and the layer positions inside the
        # Sequential determine the state_dict keys ("model.0.weight", ...),
        # so they must not change if existing checkpoints are to load.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(obs_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-action scores for the input batch ``x``."""
        return self.model(x)
def run(model, episodes):
    """Roll out ``model`` greedily in a rendered ``CartPole-v1`` environment.

    Despite its name, ``episodes`` is a *step* budget: the loop performs
    ``episodes + 1`` environment steps in total and resets the environment
    whenever an episode terminates or is truncated.

    Args:
        model: A torch module called with a float tensor holding a single
            observation (batch dimension of 1); the index of its largest
            output is used as the action.
        episodes: Step budget (see above).

    Returns:
        The sum of the rewards collected over all steps. The same value is
        printed before returning.
    """
    video_length = episodes

    # Feed observations on whatever device the model already lives on rather
    # than assuming CUDA, so the rollout also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    env = gym.make("CartPole-v1", render_mode="human")
    total_reward = 0.0
    try:
        obs, _ = env.reset()
        with torch.no_grad():
            for _ in tqdm(range(video_length + 1)):
                x = torch.as_tensor(
                    obs, dtype=torch.float32, device=device
                ).unsqueeze(0)
                action = model(x).argmax(dim=-1).item()
                obs, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward
                if terminated or truncated:
                    obs, _ = env.reset()
    finally:
        # Always release the render window, even if stepping raised.
        env.close()

    print(f"total reward : {total_reward}")
    return total_reward
# Use the GPU when one is available and fall back to the CPU otherwise, so
# the script does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the checkpoint allows it).
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
checkpoint = torch.load(
    os.path.join(os.getcwd(), "99.61_99_policy_net.pth"),
    map_location=device,
)

model = CartPole()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode for the rollout below
run(model=model, episodes=500)