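Playing Super Mario Bros with the PPO Algorithm
Do you still remember Super Mario Bros, and the frustration of getting stuck on a level? Train your own AI and watch it clear stage after stage.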

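Objectives

By working through this case and the follow-up exercise, you will:

  1. Understand the basic concepts of the PPO algorithm
  2. Understand how to train an agent to play a small game with PPO
  3. Understand the overall reinforcement learning workflow of training and running inference on a game

You can also share your ipynb notes for this case in the AI Gallery Notebook section to earn growth points; the AI Gallery sharing guide describes how.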

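Overview

In this tutorial, we use the PPO algorithm to play Super Mario Bros. So far, the agent has learned to clear the vast majority of levels within 1500 episodes. You can set the level you want to play and the training hyperparameters in the hyperparameter section.

Overall workflow: create the Mario environment -> build the PPO algorithm -> train -> run inference -> visualize the result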

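Basic structure of the PPO algorithm

PPO has two main variants: PPO-Penalty and PPO-Clip (PPO2). Here we discuss PPO-Clip, the primary variant used by OpenAI. Its main characteristics are:

  • PPO is an on-policy algorithm.

  • PPO works with both discrete and continuous action spaces.

  • Loss function: the key idea of PPO-Clip is a ratio term that measures how far the new policy has moved from the old one; the hyperparameter ϵ limits the size of each policy update.

  • Policy update: the policy parameters are updated to maximize this clipped objective on the collected experience.

  • Exploration: PPO explores stochastically, by sampling actions from the policy's output distribution.

  • Advantage function: measures how much better taking action a in state s is compared with the other actions. If it is greater than 0, the action is better than the average action; otherwise it is worse.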
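
The clipped objective described in the list above can be sketched in a few lines of plain Python. This is a minimal illustration, assuming the standard PPO-Clip form from the PPO paper, min(r * A, clip(r, 1 - ϵ, 1 + ϵ) * A) with r the probability ratio between the new and old policy; the function name `clipped_surrogate` and the sample numbers are illustrative and not part of the tutorial.

```python
import math


def clipped_surrogate(new_log_prob, old_log_prob, advantage, epsilon=0.2):
    """Per-sample PPO-Clip objective (a quantity to maximize)."""
    # Probability ratio between the new and the old policy for the taken action.
    ratio = math.exp(new_log_prob - old_log_prob)
    # Keep the ratio inside [1 - epsilon, 1 + epsilon].
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    # Take the more pessimistic of the unclipped and clipped terms.
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: the ratio 1.5 is clipped to 1.2, so the gain is capped.
gain = clipped_surrogate(math.log(0.6), math.log(0.4), advantage=2.0)

# Negative advantage: the ratio 0.5 is clipped to 0.8, and the minimum
# keeps the more negative (clipped) term.
penalized = clipped_surrogate(math.log(0.2), math.log(0.4), advantage=-1.0)

print(gain, penalized)
```

In the training code later in this case, the same quantity is computed on tensors with `torch.clamp` and negated to form the actor loss.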

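For the main steps of the algorithm, see the PPO paper (Schulman et al., 2017, "Proximal Policy Optimization Algorithms").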

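Introduction to the Super Mario Bros environment

Super Mario Bros is a famous side-scrolling platform game developed by Nintendo and released in 1985. The goal is to travel through the Mushroom Kingdom and rescue Princess Peach from the villain Bowser. Mario can collect coins scattered around the game world, or hit special bricks to obtain the coins or power-ups inside. The game has 8 worlds, each with 4 stages.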
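
Each of the 8 x 4 levels maps to one environment id. The sketch below enumerates them, assuming the `SuperMarioBros-{world}-{stage}-v0` id pattern that the environment-creation code in this case passes to `gym_super_mario_bros.make`; the helper name `mario_env_id` is hypothetical and only builds the string, it does not create an environment.

```python
def mario_env_id(world, stage, version=0):
    """Build the environment id for one level (hypothetical helper)."""
    if not 1 <= world <= 8:
        raise ValueError("world must be in 1..8")
    if not 1 <= stage <= 4:
        raise ValueError("stage must be in 1..4")
    return "SuperMarioBros-{}-{}-v{}".format(world, stage, version)


# All 32 levels: 8 worlds, each with 4 stages.
all_levels = [mario_env_id(w, s) for w in range(1, 9) for s in range(1, 5)]

print(len(all_levels), all_levels[0], all_levels[-1])
```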

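Notes

  1. This case runs in the Pytorch-1.0.0 environment and requires a GPU. See the "ModelArts JupyterLab Hardware Specification Guide" to learn how to switch hardware specifications.

  2. If this is your first time using JupyterLab, see the "ModelArts JupyterLab User Guide" to learn how to use it.

  3. If you run into errors while using JupyterLab, see "ModelArts JupyterLab Common Problems and Solutions" for troubleshooting.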

Steps

1. Program initialization

Step 1: Install the base dependencies

!pip install -U pip
!pip install gym==0.19.0
!pip install tqdm==4.48.0
!pip install nes-py==8.1.0
!pip install gym-super-mario-bros==7.3.2

Output (abridged to the final status lines and the dependency warnings):

Successfully installed pip-21.3.1
Successfully installed cloudpickle-1.6.0 gym-0.19.0
Successfully installed tqdm-4.48.0
Successfully installed nes-py-8.1.0 pyglet-1.5.23
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
modelarts 1.4.0 requires configparser<=5.0.2, but you have configparser 5.2.0 which is incompatible.
modelarts 1.4.0 requires tqdm<=4.62.3, but you have tqdm 4.63.0 which is incompatible.
Successfully installed gym-super-mario-bros-7.3.2 importlib-resources-5.4.0 nes-py-8.1.8 pyglet-1.5.11 tqdm-4.63.0

Note that gym-super-mario-bros 7.3.2 requires nes-py>=8.1.2, which in turn requires tqdm>=4.48.2 and pyglet<=1.5.11. As a result, pip replaces the pinned nes-py 8.1.0 and tqdm 4.48.0 with nes-py 8.1.8 and tqdm 4.63.0, and downgrades pyglet from 1.5.23 to 1.5.11. The reported conflicts concern the preinstalled modelarts 1.4.0 package; the installation itself completes.

Step 2: Import the required libraries

import os
import shutil
import subprocess as sp
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as _mp
from torch.distributions import Categorical
import torch.multiprocessing as mp
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym.spaces import Box
from gym import Wrapper
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT, RIGHT_ONLY
import cv2
import matplotlib.pyplot as plt
from IPython import display

import moxing as mox
INFO:root:Using MoXing-v1.17.3-

INFO:root:Using OBS-Python-SDK-3.20.7

2. Initialize the training parameters

You can adjust these parameters to train for better results.

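opt={
    "world": 1,                # world to play: 1,2,3,4,5,6,7,8
    "stage": 1,                # stage within the world: 1,2,3,4
    "action_type": "simple",   # action set: "simple", "right_only", "complex"
    'lr': 1e-4,                # suggested learning rates: 1e-3, 1e-4, 1e-5, 7e-5
    'gamma': 0.9,              # reward discount factor
    'tau': 1.0,                # GAE parameter
    'beta': 0.01,              # entropy coefficient
    'epsilon': 0.2,            # PPO clip coefficient
    'batch_size': 16,          # number of mini-batches the collected experience is split into per epoch
    'max_episode':10,          # maximum number of training episodes
    'num_epochs': 10,          # number of passes over the collected experience per episode
    "num_local_steps": 512,    # maximum number of steps per episode
    "num_processes": 8,        # number of training processes, usually the number of CPU cores
    "save_interval": 5,        # save the model every {} episodes
    "log_path": "./log",       # log directory
    "saved_path": "./model",   # directory for saved models
    "pretrain_model": True,    # whether to load the pretrained model; only a model for level 1-1 is provided, other levels train from scratch
    "episode":5                # episode number of the checkpoint that infer() loads
}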
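
To see how these values interact, the arithmetic below works out the amount of experience collected per episode and the size of each mini-batch. It is a small sketch based on how the training loop in this case uses the parameters (it iterates `range(batch_size)` and slices the shuffled indices into that many chunks); the dictionary name `params` and the derived variable names are illustrative.

```python
# Values copied from the opt dictionary above.
params = {
    "num_local_steps": 512,
    "num_processes": 8,
    "batch_size": 16,
    "num_epochs": 10,
}

# Each episode collects num_local_steps transitions from every process.
samples_per_episode = params["num_local_steps"] * params["num_processes"]

# The training loop splits those samples into batch_size mini-batches.
minibatch_size = samples_per_episode // params["batch_size"]

# One optimizer step per mini-batch, repeated for num_epochs passes.
updates_per_episode = params["num_epochs"] * params["batch_size"]

print(samples_per_episode, minibatch_size, updates_per_episode)
```

With the defaults, raising `batch_size` therefore makes each mini-batch smaller rather than larger.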

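3. Create the environment

Termination conditions:

  • Win: Mario reaches the end of the stage

  • Lose: Mario is hurt by an enemy, falls into a pit, or runs out of time

Reward function:

  • Bonus: collecting coins, stomping enemies, reaching the flag at the end

  • Penalty: being hurt by an enemy, falling into a pit, ending without reaching the flag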

# Create the environment
def create_train_env(world, stage, actions, output_path=None):
    # Create the base environment
    env = gym_super_mario_bros.make("SuperMarioBros-{}-{}-v0".format(world, stage))

    env = JoypadSpace(env, actions)
    # Customize the environment (reward shaping and frame skipping)
    env = CustomReward(env, world, stage, monitor=None)
    env = CustomSkipFrame(env)
    return env


# Modify the original environment to get better training results
class CustomReward(Wrapper):
    def __init__(self, env=None, world=None, stage=None, monitor=None):
        super(CustomReward, self).__init__(env)
        self.observation_space = Box(low=0, high=255, shape=(1, 84, 84))
        self.curr_score = 0
        self.current_x = 40
        self.world = world
        self.stage = stage
        if monitor:
            self.monitor = monitor
        else:
            self.monitor = None

    def step(self, action):
        state, reward, done, info = self.env.step(action)
        if self.monitor:
            self.monitor.record(state)
        state = process_frame(state)
        reward += (info["score"] - self.curr_score) / 40.
        self.curr_score = info["score"]
        if done:
            if info["flag_get"]:
                reward += 50
            else:
                reward -= 50
        if self.world == 7 and self.stage == 4:
            if (506 <= info["x_pos"] <= 832 and info["y_pos"] > 127) or (
                    832 < info["x_pos"] <= 1064 and info["y_pos"] < 80) or (
                    1113 < info["x_pos"] <= 1464 and info["y_pos"] < 191) or (
                    1579 < info["x_pos"] <= 1943 and info["y_pos"] < 191) or (
                    1946 < info["x_pos"] <= 1964 and info["y_pos"] >= 191) or (
                    1984 < info["x_pos"] <= 2060 and (info["y_pos"] >= 191 or info["y_pos"] < 127)) or (
                    2114 < info["x_pos"] < 2440 and info["y_pos"] < 191) or info["x_pos"] < self.current_x - 500:
                reward -= 50
                done = True
        if self.world == 4 and self.stage == 4:
            if (info["x_pos"] <= 1500 and info["y_pos"] < 127) or (
                    1588 <= info["x_pos"] < 2380 and info["y_pos"] >= 127):
                reward = -50
                done = True

        self.current_x = info["x_pos"]
        return state, reward / 10., done, info

    def reset(self):
        self.curr_score = 0
        self.current_x = 40
        return process_frame(self.env.reset())


class MultipleEnvironments:
    def __init__(self, world, stage, action_type, num_envs, output_path=None):
        self.agent_conns, self.env_conns = zip(*[mp.Pipe() for _ in range(num_envs)])
        if action_type == "right_only":
            actions = RIGHT_ONLY
        elif action_type == "simple":
            actions = SIMPLE_MOVEMENT
        else:
            actions = COMPLEX_MOVEMENT
        self.envs = [create_train_env(world, stage, actions, output_path=output_path) for _ in range(num_envs)]
        self.num_states = self.envs[0].observation_space.shape[0]
        self.num_actions = len(actions)
        for index in range(num_envs):
            process = mp.Process(target=self.run, args=(index,))
            process.start()
            self.env_conns[index].close()

    def run(self, index):
        self.agent_conns[index].close()
        while True:
            request, action = self.env_conns[index].recv()
            if request == "step":
                self.env_conns[index].send(self.envs[index].step(action.item()))
            elif request == "reset":
                self.env_conns[index].send(self.envs[index].reset())
            else:
                raise NotImplementedError


def process_frame(frame):
    if frame is not None:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (84, 84))[None, :, :] / 255.
        return frame
    else:
        return np.zeros((1, 84, 84))
    

class CustomSkipFrame(Wrapper):
    def __init__(self, env, skip=4):
        super(CustomSkipFrame, self).__init__(env)
        self.observation_space = Box(low=0, high=255, shape=(skip, 84, 84))
        self.skip = skip
        self.states = np.zeros((skip, 84, 84), dtype=np.float32)

    def step(self, action):
        total_reward = 0
        last_states = []
        for i in range(self.skip):
            state, reward, done, info = self.env.step(action)
            total_reward += reward
            if i >= self.skip / 2:
                last_states.append(state)
            if done:
                self.reset()
                return self.states[None, :, :, :].astype(np.float32), total_reward, done, info
        max_state = np.max(np.concatenate(last_states, 0), 0)
        self.states[:-1] = self.states[1:]
        self.states[-1] = max_state
        return self.states[None, :, :, :].astype(np.float32), total_reward, done, info

    def reset(self):
        state = self.env.reset()
        self.states = np.concatenate([state for _ in range(self.skip)], 0)
        return self.states[None, :, :, :].astype(np.float32)
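
As a worked example of the reward shaping in `CustomReward.step`, the sketch below reproduces the general case in plain Python: the score increase divided by 40 is added to the environment reward, 50 is added or subtracted when the episode ends, and the result is divided by 10. The function name `shaped_reward` is illustrative, and the level-specific rules for 7-4 and 4-4 are deliberately left out.

```python
def shaped_reward(env_reward, score, prev_score, done, flag_get):
    """General-case reward shaping for a single emulator step."""
    # Reward in-game score gains (coins, stomped enemies).
    reward = env_reward + (score - prev_score) / 40.0
    if done:
        # Large bonus for reaching the flag, large penalty otherwise.
        reward += 50 if flag_get else -50
    # Scale down to keep the reward magnitude small.
    return reward / 10.0


# Moving right (+1 from the environment) while gaining 200 points.
step_reward = shaped_reward(1, score=200, prev_score=0, done=False, flag_get=False)

# Episode ends at the flag versus ends by dying, with no other reward.
win_reward = shaped_reward(0, score=0, prev_score=0, done=True, flag_get=True)
lose_reward = shaped_reward(0, score=0, prev_score=0, done=True, flag_get=False)

print(step_reward, win_reward, lose_reward)
```

`CustomSkipFrame` then sums these per-step rewards over the frames it skips.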

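4. Define the neural network

The network consists of four convolutional layers followed by one fully connected layer. The extracted features are fed to a critic head and an actor head, which output the state value and the action logits (turned into an action probability distribution by softmax), respectively.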

class Net(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        self.linear = nn.Linear(32 * 6 * 6, 512)
        self.critic_linear = nn.Linear(512, 1)
        self.actor_linear = nn.Linear(512, num_actions)
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, nn.init.calculate_gain('relu'))
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.linear(x.view(x.size(0), -1))
        return self.actor_linear(x), self.critic_linear(x)
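
The input size `32 * 6 * 6` of the fully connected layer follows from the convolution arithmetic. The sketch below applies the usual output-size formula for a convolution, floor((n + 2 * padding - kernel) / stride) + 1, four times to an 84 x 84 frame; the helper name `conv_out` is illustrative.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial output size of one convolution along a single dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 84
sizes = [size]
for _ in range(4):  # four conv layers with kernel 3, stride 2, padding 1
    size = conv_out(size)
    sizes.append(size)

# 32 output channels times the final 6 x 6 feature map.
flattened_features = 32 * size * size

print(sizes, flattened_features)
```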

5. Define the PPO algorithm

def evaluation(opt, global_model, num_states, num_actions,curr_episode):
    print('start evaluation!')
    torch.manual_seed(123)
    # "right_only" matches the value documented in opt and used by MultipleEnvironments
    if opt['action_type'] == "right_only":
        actions = RIGHT_ONLY
    elif opt['action_type'] == "simple":
        actions = SIMPLE_MOVEMENT
    else:
        actions = COMPLEX_MOVEMENT
    env = create_train_env(opt['world'], opt['stage'], actions)
    local_model = Net(num_states, num_actions)
    if torch.cuda.is_available():
        local_model.cuda()
    local_model.eval()
    state = torch.from_numpy(env.reset())
    if torch.cuda.is_available():
        state = state.cuda()
    
    plt.figure(figsize=(10,10))
    img = plt.imshow(env.render(mode='rgb_array'))
    
    done=False
    local_model.load_state_dict(global_model.state_dict())  # load the network parameters

    while not done:
        if torch.cuda.is_available():
            state = state.cuda()
        logits, value = local_model(state)
        policy = F.softmax(logits, dim=1)
        action = torch.argmax(policy).item()
        state, reward, done, info = env.step(action)
        state = torch.from_numpy(state)
        
        img.set_data(env.render(mode='rgb_array')) # just update the data
        display.display(plt.gcf())
        display.clear_output(wait=True)

        if info["flag_get"]:
            print("flag reached in episode:{}!".format(curr_episode))
            torch.save(local_model.state_dict(),
                       "{}/ppo_super_mario_bros_{}_{}_{}".format(opt['saved_path'], opt['world'], opt['stage'],curr_episode))
            opt.update({'episode':curr_episode})
            env.close()
            return True
    return False

    
def train(opt):
    # Check whether CUDA is available
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if os.path.isdir(opt['log_path']):
        shutil.rmtree(opt['log_path'])

    os.makedirs(opt['log_path'])
    if not os.path.isdir(opt['saved_path']):
        os.makedirs(opt['saved_path'])
    mp = _mp.get_context("spawn")
    # Create the environments
    envs = MultipleEnvironments(opt['world'], opt['stage'], opt['action_type'], opt['num_processes'])
    # Create the model
    model = Net(envs.num_states, envs.num_actions)
    if opt['pretrain_model']:
        print('加载预训练模型')  # prints "Loading the pretrained model"
        if not os.path.exists("ppo_super_mario_bros_1_1_0"):
            mox.file.copy_parallel(
                "obs://modelarts-labs-bj4/course/modelarts/zjc_team/reinforcement_learning/ppo_mario/ppo_super_mario_bros_1_1_0",
                "ppo_super_mario_bros_1_1_0")
        if torch.cuda.is_available():
            model.load_state_dict(torch.load("ppo_super_mario_bros_1_1_0"))
            model.cuda()
        else:
            model.load_state_dict(torch.load("ppo_super_mario_bros_1_1_0", map_location=torch.device('cpu')))
    elif torch.cuda.is_available():
        # Training from scratch: move the model to the GPU only when one is available
        model.cuda()
    model.share_memory()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt['lr'])
    # Reset the environments
    [agent_conn.send(("reset", None)) for agent_conn in envs.agent_conns]
    # Receive the initial states
    curr_states = [agent_conn.recv() for agent_conn in envs.agent_conns]
    curr_states = torch.from_numpy(np.concatenate(curr_states, 0))
    if torch.cuda.is_available():
        curr_states = curr_states.cuda()
    curr_episode = 0
    # Train until the maximum number of episodes is reached
    while curr_episode<opt['max_episode']:
        if curr_episode % opt['save_interval'] == 0 and curr_episode > 0:
            torch.save(model.state_dict(),
                       "{}/ppo_super_mario_bros_{}_{}_{}".format(opt['saved_path'], opt['world'], opt['stage'], curr_episode))
        curr_episode += 1
        old_log_policies = []
        actions = []
        values = []
        states = []
        rewards = []
        dones = []
        # Collect at most num_local_steps steps per episode
        for _ in range(opt['num_local_steps']):
            states.append(curr_states)
            logits, value = model(curr_states)
            values.append(value.squeeze())
            policy = F.softmax(logits, dim=1)
            old_m = Categorical(policy)
            action = old_m.sample()
            actions.append(action)
            old_log_policy = old_m.log_prob(action)
            old_log_policies.append(old_log_policy)
            # Execute the actions
            if torch.cuda.is_available():
                [agent_conn.send(("step", act)) for agent_conn, act in zip(envs.agent_conns, action.cpu())]
            else:
                [agent_conn.send(("step", act)) for agent_conn, act in zip(envs.agent_conns, action)]
            state, reward, done, info = zip(*[agent_conn.recv() for agent_conn in envs.agent_conns])
            state = torch.from_numpy(np.concatenate(state, 0))
            if torch.cuda.is_available():
                state = state.cuda()
                reward = torch.cuda.FloatTensor(reward)
                done = torch.cuda.FloatTensor(done)
            else:
                reward = torch.FloatTensor(reward)
                done = torch.FloatTensor(done)
            rewards.append(reward)
            dones.append(done)
            curr_states = state

        _, next_value, = model(curr_states)
        next_value = next_value.squeeze()
        old_log_policies = torch.cat(old_log_policies).detach()
        actions = torch.cat(actions)
        values = torch.cat(values).detach()
        states = torch.cat(states)
        gae = 0
        R = []
        # Compute GAE
        for value, reward, done in list(zip(values, rewards, dones))[::-1]:
            gae = gae * opt['gamma'] * opt['tau']
            gae = gae + reward + opt['gamma'] * next_value.detach() * (1 - done) - value.detach()
            next_value = value
            R.append(gae + value)
        R = R[::-1]
        R = torch.cat(R).detach()
        advantages = R - values
        # Update the policy
        for i in range(opt['num_epochs']):
            indice = torch.randperm(opt['num_local_steps'] * opt['num_processes'])
            for j in range(opt['batch_size']):
                batch_indices = indice[
                                int(j * (opt['num_local_steps'] * opt['num_processes'] / opt['batch_size'])): int((j + 1) * (
                                        opt['num_local_steps'] * opt['num_processes'] / opt['batch_size']))]
                logits, value = model(states[batch_indices])
                new_policy = F.softmax(logits, dim=1)
                new_m = Categorical(new_policy)
                new_log_policy = new_m.log_prob(actions[batch_indices])
                ratio = torch.exp(new_log_policy - old_log_policies[batch_indices])
                actor_loss = -torch.mean(torch.min(ratio * advantages[batch_indices],
                                                   torch.clamp(ratio, 1.0 - opt['epsilon'], 1.0 + opt['epsilon']) *
                                                   advantages[
                                                       batch_indices]))
                critic_loss = F.smooth_l1_loss(value.squeeze(), R[batch_indices])  # (input, target) order
                entropy_loss = torch.mean(new_m.entropy())
                # The total loss has three parts: actor loss, critic loss, and action entropy loss
                total_loss = actor_loss + critic_loss - opt['beta'] * entropy_loss
                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
        print("Episode: {}. Total loss: {}".format(curr_episode, total_loss))
        
        finish=False
        for i in range(opt["num_processes"]):
            if info[i]["flag_get"]:
                finish=evaluation(opt, model,envs.num_states, envs.num_actions,curr_episode)
                if finish:
                    break
        if finish:
            break
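
The advantage estimate used by the policy update can be illustrated on scalars. The sketch below is a minimal version of the standard GAE backward recursion, with `gamma` and `lam` playing the roles of `opt['gamma']` and `opt['tau']`. It is not a line-by-line copy of the loop in `train`: that loop follows the same backward pass but does not apply the `(1 - done)` mask to the accumulated term. The function name `gae_advantages` and the sample numbers are illustrative.

```python
def gae_advantages(rewards, values, dones, next_value, gamma=0.9, lam=1.0):
    """Standard GAE: walk the trajectory backwards and accumulate TD errors."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        mask = 1.0 - dones[t]  # 0 when the episode ended at step t
        # One-step TD error.
        delta = rewards[t] + gamma * next_value * mask - values[t]
        # Discounted sum of TD errors.
        gae = delta + gamma * lam * mask * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages


rewards = [1.0, 1.0]
values = [0.5, 0.5]
dones = [0.0, 0.0]
advantages = gae_advantages(rewards, values, dones, next_value=0.5)

# Value targets for the critic are advantage + value, as in the training loop.
returns = [a + v for a, v in zip(advantages, values)]

print(advantages, returns)
```

With `lam` equal to 1.0, as in the default `tau`, the value target for the first step reduces to the discounted return 1 + 0.9 * 1 + 0.81 * 0.5.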

6. Train the model

Training for 10 episodes takes about 5 minutes.

train(opt)
加载预训练模型

Episode: 1. Total loss: 1.1230244636535645

Episode: 2. Total loss: 2.553663730621338

Episode: 3. Total loss: 1.768389344215393

Episode: 4. Total loss: 1.6962862014770508

Episode: 5. Total loss: 1.0912611484527588

Episode: 6. Total loss: 1.6626232862472534

Episode: 7. Total loss: 1.9952025413513184

Episode: 8. Total loss: 1.2410558462142944

Episode: 9. Total loss: 1.3711413145065308

Episode: 10. Total loss: 1.2155205011367798

7. Run inference with the trained model

Define the inference function

def infer(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    # "right_only" matches the value documented in opt and used by MultipleEnvironments
    if opt['action_type'] == "right_only":
        actions = RIGHT_ONLY
    elif opt['action_type'] == "simple":
        actions = SIMPLE_MOVEMENT
    else:
        actions = COMPLEX_MOVEMENT
    env = create_train_env(opt['world'], opt['stage'], actions)
    model = Net(env.observation_space.shape[0], len(actions))
    if torch.cuda.is_available():
        model.load_state_dict(torch.load("{}/ppo_super_mario_bros_{}_{}_{}".format(opt['saved_path'],opt['world'], opt['stage'],opt['episode'])))
        model.cuda()
    else:
        model.load_state_dict(torch.load("{}/ppo_super_mario_bros_{}_{}_{}".format(opt['saved_path'], opt['world'], opt['stage'],opt['episode']),
                                         map_location=torch.device('cpu')))
    model.eval()
    state = torch.from_numpy(env.reset())
    
    plt.figure(figsize=(10,10))
    img = plt.imshow(env.render(mode='rgb_array'))
    
    while True:
        if torch.cuda.is_available():
            state = state.cuda()
        logits, value = model(state)
        policy = F.softmax(logits, dim=1)
        action = torch.argmax(policy).item()
        state, reward, done, info = env.step(action)
        state = torch.from_numpy(state)
        
        img.set_data(env.render(mode='rgb_array')) # just update the data
        display.display(plt.gcf())
        display.clear_output(wait=True)
        
        if info["flag_get"]:
            print("World {} stage {} completed".format(opt['world'], opt['stage']))
            break
            
        if done and info["flag_get"] is False:
            print('Game Failed')
            break
infer(opt)
World 1 stage 1 completed
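
Inference picks the most probable action with `torch.argmax`, whereas training samples from the `Categorical` distribution. The plain-Python sketch below contrasts the two on a toy set of logits; the helper names and the logits are illustrative, and `random.Random(123)` only stands in for the seeded sampling in the notebook.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_action(logits):
    """Inference-style choice: index of the most probable action."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)


def sampled_action(logits, rng):
    """Training-style choice: draw an action according to its probability."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
greedy = greedy_action(logits)

rng = random.Random(123)
samples = [sampled_action(logits, rng) for _ in range(1000)]

print(greedy, samples.count(greedy))
```

Greedy selection makes the playback deterministic for a given model, while sampling keeps some exploration during training.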

8. Exercise

  1. Adjust the training parameters in step 2 and retrain a model so that it performs better in the game.
