之前在《openpi论文及代码解析(A Vision-Language-Action Flow Model for General Robot Control)(一)》中梳理了openpi的基础背景, 本篇文章将对相关代码进行整理. 这里推荐一下RoboTwin项目, 其代码对RDT、PI0等模型进行了修改和优化, 以下部分内容都参照了该代码.
一、 数据整理
1. ALOHA数据集转LEROBOT数据集
仿照之前RDT的方案(参见《RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation - 如何训练自己数据集》), 以hdf5格式进行数据采集. 之后基于脚本convert_aloha_data_to_lerobot.py, 将采集的hdf5数据集转换成lerobot数据格式; 注意lerobot库一定要按照要求的版本进行安装. 原脚本convert_aloha_data_to_lerobot.py 只基于单个任务进行数据封装(所有episode共用同一个task描述), 不支持多任务, 因此我们需要修改部分代码, 使模型支持多任务训练. 修改点如下:
- 将LEROBOT_HOME改为HF_LEROBOT_HOME(新版本不再支持LEROBOT_HOME);
- motors: 将left和right对调, 与RDT保持一致, 都是先左后右;
- cameras: 如果是三个相机, 需要删除cam_low;
- 为了保证每个任务和它的任务描述绑定, 需要修改populate_dataset, 修改如下:
def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Append the selected raw hdf5 episodes to ``dataset``, one frame at a time.

    Each episode is bound to its own language instruction. An
    ``instructions.json`` file must sit in the same directory as the episode's
    hdf5 file. One entry of its ``"instructions"`` list is sampled at random
    and used as the ``task`` of every frame of that episode.

    Args:
        dataset: Dataset to append to (e.g. created by ``create_empty_dataset``).
        hdf5_files: Raw episode files.
        task: Not used; kept for interface compatibility. The instruction
            always comes from ``instructions.json``.
        episodes: Indices into ``hdf5_files`` to convert; all files when None.

    Returns:
        ``dataset`` itself, after ``save_episode`` was called for every episode.

    Raises:
        FileNotFoundError: if an episode has no ``instructions.json`` next to it.
        ValueError: if the ``"instructions"`` list in that file is empty.
    """
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes):
        ep_path = hdf5_files[ep_idx]
        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
        num_frames = state.shape[0]

        # os.path.join (not f"{dir}/...") so a bare file name resolves to the
        # current directory instead of the filesystem root.
        instructions_path = os.path.join(os.path.dirname(ep_path), "instructions.json")
        # Explicit UTF-8: instructions may contain non-ASCII characters (e.g.
        # typographic quotes), and open()'s default encoding is locale-dependent.
        with open(instructions_path, encoding="utf-8") as f_instr:
            instructions = json.load(f_instr)["instructions"]
        if not instructions:
            raise ValueError(f"No instructions listed in {instructions_path}")
        # str(): store a plain Python string rather than a numpy scalar string.
        instruction = str(np.random.choice(instructions))

        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
                "task": instruction,
            }
            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]
            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]
            dataset.add_frame(frame)

        dataset.save_episode()

    return dataset
整体代码见 convert_aloha_data_to_lerobot_robot.py, 运行方式如下:
python3 convert_aloha_data_to_lerobot_robot.py --raw-dir xxx --repo-id xxx
如果数据集较大, 可以调大DatasetConfig中的以下参数来加速转换; 同时可以通过环境变量HF_LEROBOT_HOME指定数据的保存路径, 运行脚本即可:
HF_LEROBOT_HOME=你要保存数据路径 python3 convert_aloha_data_to_lerobot_robot.py --raw-dir xxx --repo-id xxx
image_writer_processes: int = 10
image_writer_threads: int = 5
运行完上述脚本后, 会在生成的数据路径中得到data和meta两个文件夹.
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro
import json
import os
import fnmatch
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    """Options forwarded to ``LeRobotDataset.create`` by ``create_empty_dataset``."""

    # Passed as ``use_videos``.
    use_videos: bool = True
    # Passed as ``tolerance_s``; presumably a timestamp tolerance in seconds — confirm in LeRobot.
    tolerance_s: float = 0.0001
    # Image-writer workers passed to LeRobot; raise these to speed up converting large datasets.
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    # Passed as ``video_backend``; None presumably lets LeRobot pick its default.
    video_backend: str | None = None


# Shared default instance; the dataclass is frozen, so it is safe as a default argument.
DEFAULT_DATASET_CONFIG = DatasetConfig()
def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    fps: int = 50,
) -> LeRobotDataset:
    """Create an empty LeRobot dataset with the bimanual (14-joint, 3-camera) schema.

    Any previous output at ``HF_LEROBOT_HOME / repo_id`` is deleted first.

    Args:
        repo_id: Dataset name; also the sub-directory under ``HF_LEROBOT_HOME``.
        robot_type: Robot type recorded in the dataset metadata.
        mode: Feature dtype used for the camera streams ("video" or "image").
        has_velocity: Add an ``observation.velocity`` feature.
        has_effort: Add an ``observation.effort`` feature.
        dataset_config: Writer/video options forwarded to ``LeRobotDataset.create``.
        fps: Frame rate stored in the dataset metadata. LeRobot converts delta
            timestamps into frame offsets with it, so it should match the rate
            the raw episodes were recorded at. Defaults to 50 as before.

    Returns:
        The newly created, empty ``LeRobotDataset``.
    """
    # Left arm first, then right arm (same order as RDT); index 6 and 13 are the grippers.
    motors = [
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
    ]
    cameras = [
        "cam_high",
        "cam_left_wrist",
        "cam_right_wrist",
    ]

    def _joint_vector_feature() -> dict:
        # One float32 value per motor; a fresh dict per feature.
        return {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [motors],
        }

    features = {
        "observation.state": _joint_vector_feature(),
        "action": _joint_vector_feature(),
    }
    if has_velocity:
        features["observation.velocity"] = _joint_vector_feature()
    if has_effort:
        features["observation.effort"] = _joint_vector_feature()
    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": ["channels", "height", "width"],
        }

    # Start from a clean directory so results of a previous conversion are not mixed in.
    output_dir = HF_LEROBOT_HOME / repo_id
    if output_dir.exists():
        shutil.rmtree(output_dir)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )
def get_cameras(hdf5_files: list[Path]) -> list[str]:
    """Return the camera stream names of the first episode, skipping depth streams."""
    first_episode = hdf5_files[0]
    with h5py.File(first_episode, "r") as ep:
        image_group = ep["/observations/images"]
        # Depth channels are not handled by the converter, so leave them out.
        names = [name for name in image_group.keys() if "depth" not in name]  # noqa: SIM118
    return names
def has_velocity(hdf5_files: list[Path]) -> bool:
    """Report whether the first episode stores ``/observations/qvel``."""
    with h5py.File(hdf5_files[0], "r") as ep:
        found = "/observations/qvel" in ep
    return found
def has_effort(hdf5_files: list[Path]) -> bool:
    """Report whether the first episode stores ``/observations/effort``."""
    with h5py.File(hdf5_files[0], "r") as ep:
        found = "/observations/effort" in ep
    return found
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every frame of each requested camera stream into memory.

    A 4-D dataset is treated as raw frames and read in one go. Anything else is
    treated as a sequence of encoded image buffers and decoded one by one with
    OpenCV.

    NOTE(review): ``cv2.imdecode`` returns channels in BGR order and no
    conversion is done here; confirm this matches how the frames were encoded.
    """
    frames_by_camera: dict[str, np.ndarray] = {}
    for cam_name in cameras:
        stream = ep[f"/observations/images/{cam_name}"]
        if stream.ndim == 4:
            # Raw frames: read the whole stack at once.
            frames_by_camera[cam_name] = stream[:]
            continue

        import cv2

        # Compressed frames: decode one buffer after the other.
        decoded = [cv2.imdecode(np.frombuffer(buf, np.uint8), cv2.IMREAD_COLOR) for buf in stream]
        frames_by_camera[cam_name] = np.array(decoded)
    return frames_by_camera
def load_raw_episode_data(
    ep_path: Path,
    cameras: list[str] | None = None,
) -> tuple[
    dict[str, np.ndarray],
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Read one raw hdf5 episode.

    Args:
        ep_path: Path of the episode's hdf5 file.
        cameras: Camera streams to load. Defaults to the three-camera setup
            (``cam_high``, ``cam_left_wrist``, ``cam_right_wrist``), which was
            previously hard-coded.

    Returns:
        ``(imgs_per_cam, state, action, velocity, effort)``, where state comes
        from ``/observations/qpos`` and action from ``/action``. ``velocity``
        and ``effort`` are None when ``/observations/qvel`` or
        ``/observations/effort`` is absent.
    """
    if cameras is None:
        cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]

    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        velocity = None
        if "/observations/qvel" in ep:
            velocity = torch.from_numpy(ep["/observations/qvel"][:])

        effort = None
        if "/observations/effort" in ep:
            effort = torch.from_numpy(ep["/observations/effort"][:])

        imgs_per_cam = load_raw_images_per_camera(ep, cameras)

    return imgs_per_cam, state, action, velocity, effort
def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Append the selected raw hdf5 episodes to ``dataset``, one frame at a time.

    Each episode is bound to its own language instruction. An
    ``instructions.json`` file must sit in the same directory as the episode's
    hdf5 file. One entry of its ``"instructions"`` list is sampled at random
    and used as the ``task`` of every frame of that episode.

    Args:
        dataset: Dataset to append to (e.g. created by ``create_empty_dataset``).
        hdf5_files: Raw episode files.
        task: Not used; kept for interface compatibility. The instruction
            always comes from ``instructions.json``.
        episodes: Indices into ``hdf5_files`` to convert; all files when None.

    Returns:
        ``dataset`` itself, after ``save_episode`` was called for every episode.

    Raises:
        FileNotFoundError: if an episode has no ``instructions.json`` next to it.
        ValueError: if the ``"instructions"`` list in that file is empty.
    """
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes):
        ep_path = hdf5_files[ep_idx]
        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
        num_frames = state.shape[0]

        # os.path.join (not f"{dir}/...") so a bare file name resolves to the
        # current directory instead of the filesystem root.
        instructions_path = os.path.join(os.path.dirname(ep_path), "instructions.json")
        # Explicit UTF-8: instructions may contain non-ASCII characters (e.g.
        # typographic quotes), and open()'s default encoding is locale-dependent.
        with open(instructions_path, encoding="utf-8") as f_instr:
            instructions = json.load(f_instr)["instructions"]
        if not instructions:
            raise ValueError(f"No instructions listed in {instructions_path}")
        # str(): store a plain Python string rather than a numpy scalar string.
        instruction = str(np.random.choice(instructions))

        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
                "task": instruction,
            }
            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]
            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]
            dataset.add_frame(frame)

        dataset.save_episode()

    return dataset
def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = False,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    """Convert every ``*.hdf5`` episode found (recursively) under ``raw_dir`` into a LeRobot dataset.

    Args:
        raw_dir: Root directory that is searched recursively for ``*.hdf5`` files.
        repo_id: Name of the output dataset under ``HF_LEROBOT_HOME``.
        raw_repo_id: Only checked when ``raw_dir`` is missing; downloading is
            disabled in this version.
        task: Forwarded to ``populate_dataset``.
        episodes: Indices into the sorted list of discovered files; all when None.
        push_to_hub: Push the finished dataset to the Hugging Face Hub.
        is_mobile: Record ``robot_type`` as "mobile_aloha" instead of "aloha".
        mode: Storage mode for camera features.
        dataset_config: Writer options forwarded to ``create_empty_dataset``.

    Raises:
        ValueError: if ``raw_dir`` does not exist and ``raw_repo_id`` is None,
            or if no ``*.hdf5`` file is found under ``raw_dir``.
    """
    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        # download_raw(raw_dir, repo_id=raw_repo_id)  # downloading is disabled in this version

    # os.walk order depends on the filesystem; sort so episode indices (and the
    # `episodes` selection) are reproducible across runs and machines.
    hdf5_files = sorted(
        Path(root) / filename
        for root, _, files in os.walk(raw_dir)
        for filename in fnmatch.filter(files, "*.hdf5")
    )
    if not hdf5_files:
        raise ValueError(f"No .hdf5 files found under {raw_dir}")

    # Inputs are validated before anything is deleted: create_empty_dataset
    # removes the previous HF_LEROBOT_HOME / repo_id output itself.
    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    tyro.cli(port_aloha)
2. 配置文件设置
在config.py文件中的列表_CONFIGS加入自己模型的配置信息, 我的配置文件如下所示:
TrainConfig(
    # Config name, selected on the command line (e.g. compute_norm_stats.py --config-name pi0_aloha_custom).
    name="pi0_aloha_custom",
    # NOTE(review): presumably caps the tokenized prompt length; long task descriptions may need more — confirm in Pi0Config.
    model=pi0.Pi0Config(max_token_len=90),
    batch_size=64,
    data=LeRobotAlohaDataConfig(
        # Keep joints/grippers in this dataset's own convention: with adapt_to_pi=False,
        # AlohaInputs/AlohaOutputs skip the joint flip and the gripper conversions.
        adapt_to_pi=False,
        # Name of the converted LeRobot dataset (the --repo-id used during conversion).
        repo_id="R1-6-kinova-subdata",
        # Reuse the assets shipped with the pi0_base checkpoint ("trossen").
        assets=AssetsConfig(
            assets_dir="/data/checkpoints_dirs/pi0/checkpoint/openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
        repack_transforms=_transforms.Group(
            inputs=[
                # Rename LeRobot dataset keys to the keys AlohaInputs expects.
                _transforms.RepackTransform(
                    {
                        "images": {
                            "cam_high": "observation.images.cam_high",
                            "cam_left_wrist": "observation.images.cam_left_wrist",
                            "cam_right_wrist": "observation.images.cam_right_wrist",
                        },
                        "state": "observation.state",
                        "actions": "action",
                        # Filled from the sample's task by PromptFromLeRobotTask (prompt_from_task=True).
                        "prompt": "prompt",
                    }
                )
            ]
        ),
        base_config=DataConfig(
            local_files_only=True,  # Set to True for local-only datasets.
            # Derive the prompt from each sample's LeRobot task description.
            prompt_from_task=True,
        ),
    ),
    # Initialize from the pi0_base parameters.
    weight_loader=weight_loaders.CheckpointWeightLoader("/data/checkpoints_dirs/pi0/checkpoint/openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=20000,
    fsdp_devices=1,
),
3. 计算数据信息
因为openpi代码一直在更新维护, 很多代码需要相应修改. compute_norm_stats.py脚本用于计算给定配置的归一化统计数据: 它会计算数据集中数据的均值和标准差, 并将其保存到配置的assets目录中.
运行脚本如下:
python3 compute_norm_stats.py --config-name 你的配置名称
compute_norm_stats.py 代码修改如下所示:
"""Compute normalization statistics for a config.
This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save it
to the config assets directory.
"""
import numpy as np
import tqdm
import tyro
import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms
class RemoveStrings(transforms.DataTransformFn):
    """Drop every entry whose value has a string dtype (e.g. the prompt)."""

    def __call__(self, x: dict) -> dict:
        kept = {}
        for key, value in x.items():
            if np.issubdtype(np.asarray(value).dtype, np.str_):
                continue
            kept[key] = value
        return kept
def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a batched torch loader of transformed samples for norm-stat computation.

    The LeRobot dataset is passed through the config's repack and data
    transforms, and string-valued fields are dropped.

    Args:
        data_config: Data config; must have a ``repo_id``.
        action_horizon: Number of future action frames per sample.
        batch_size: Samples per batch.
        model_config: Model config forwarded to ``create_torch_dataset``.
        max_frames: Optional cap on the number of frames to use.

    Returns:
        ``(data_loader, num_batches)``.

    Raises:
        ValueError: if ``data_config.repo_id`` is None, or if fewer frames are
            available (or allowed by ``max_frames``) than one full batch.
    """
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")

    dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )

    if max_frames is not None and max_frames < len(dataset):
        # Only a subset is used; shuffling presumably spreads it over the whole dataset.
        num_batches = max_frames // batch_size
        shuffle = True
    else:
        num_batches = len(dataset) // batch_size
        shuffle = False

    # Integer division gives 0 when there are fewer frames than batch_size;
    # fail here with an actionable message instead of ending up with no data.
    if num_batches < 1:
        raise ValueError(
            f"Need at least {batch_size} frames to build one batch for norm stats "
            f"(dataset has {len(dataset)}, max_frames={max_frames}); "
            "reduce the batch size or increase max_frames."
        )

    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=batch_size,
        num_workers=8,
        shuffle=shuffle,
        num_batches=num_batches,
    )
    return data_loader, num_batches
def create_rlds_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a batched RLDS loader for norm-stat computation.

    Returns the loader and the number of batches it yields. That is every full
    batch of the dataset, or ``max_frames // batch_size`` batches when
    ``max_frames`` is smaller than the dataset.
    """
    raw_dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        # Strings are not supported by JAX and are not needed for the statistics.
        RemoveStrings(),
    ]
    dataset = _data_loader.transform_iterable_dataset(raw_dataset, pipeline, is_batched=True)

    use_cap = max_frames is not None and max_frames < len(dataset)
    frame_count = max_frames if use_cap else len(dataset)
    num_batches = frame_count // batch_size

    return _data_loader.RLDSDataLoader(dataset, num_batches=num_batches), num_batches
def main(config_name: str, max_frames: int | None = None):
    """Compute and save normalization statistics for the training config ``config_name``.

    Running statistics of ``state`` and ``actions`` are accumulated over the
    (optionally capped by ``max_frames``) dataset. They are written to
    ``config.assets_dirs / repo_id``.
    """
    config = _config.get_config(config_name)
    data_config = config.data.create(config.assets_dirs, config.model)
    output_path = config.assets_dirs / data_config.repo_id

    if data_config.rlds_data_dir is not None:
        data_loader, num_batches = create_rlds_dataloader(
            data_config, config.model.action_horizon, config.batch_size, max_frames
        )
    else:
        data_loader, num_batches = create_torch_dataloader(
            data_config, config.model.action_horizon, config.batch_size, config.model, max_frames
        )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}
    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        for key in keys:
            # Use the whole batch. Every leading axis (batch, plus the action
            # horizon for "actions") is flattened into rows of feature vectors.
            # Indexing [0] here would keep only the first sample of each batch.
            values = np.asarray(batch[key])
            stats[key].update(values.reshape(-1, values.shape[-1]))

    norm_stats = {key: running.get_statistics() for key, running in stats.items()}
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)


if __name__ == "__main__":
    tyro.cli(main)
同时需要修改aloha_policy.py: 去掉cam_low, 并把夹爪相关的归一化上下限都改为0和1(原有常数以注释形式保留), 代码如下所示:
import dataclasses
from typing import ClassVar
import einops
import numpy as np
from openpi import transforms
def make_aloha_example() -> dict:
    """Build a random dummy observation in the layout AlohaInputs expects.

    Three cameras are used (cam_low is left out); each image is a random
    uint8 array of shape [3, 224, 224].
    """
    camera_names = ("cam_high", "cam_left_wrist", "cam_right_wrist")

    def _random_chw_image():
        return np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)

    return {
        "state": np.ones((14,)),
        "images": {name: _random_chw_image() for name in camera_names},
        "prompt": "do something",
    }
@dataclasses.dataclass(frozen=True)
class AlohaInputs(transforms.DataTransformFn):
    """Map a repacked Aloha sample to the pi0 model input layout.

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]; every name must be in EXPECTED_CAMERAS.
    - state: [14]
    - actions: [action_horizon, 14] (training only)
    """

    # Model action dimension; state and actions are zero-padded up to it.
    action_dim: int

    # When True, joints and grippers are converted from the standard Aloha space
    # to the space used by the pi internal runtime that trained the base model.
    adapt_to_pi: bool = True

    # Allowed camera names (cam_low is not part of this three-camera setup).
    # Missing wrist cameras are replaced by black images with a False mask.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)
        # Pad the 14-dim state up to the model action dimension.
        padded_state = transforms.pad_to_dim(data["state"], self.action_dim)

        cams = data["images"]
        if set(cams) - set(self.EXPECTED_CAMERAS):
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(cams)}")

        # The top camera is mandatory; it also gives the shape of placeholder images.
        top = cams["cam_high"]
        images = {"base_0_rgb": top}
        image_masks = {"base_0_rgb": np.True_}
        wrist_sources = (
            ("left_wrist_0_rgb", "cam_left_wrist"),
            ("right_wrist_0_rgb", "cam_right_wrist"),
        )
        for model_key, cam_key in wrist_sources:
            present = cam_key in cams
            images[model_key] = cams[cam_key] if present else np.zeros_like(top)
            image_masks[model_key] = np.True_ if present else np.False_

        result = {
            "image": images,
            "image_mask": image_masks,
            "state": padded_state,
        }

        # Actions exist only during training.
        if "actions" in data:
            encoded = _encode_actions_inv(np.asarray(data["actions"]), adapt_to_pi=self.adapt_to_pi)
            result["actions"] = transforms.pad_to_dim(encoded, self.action_dim)

        if "prompt" in data:
            result["prompt"] = data["prompt"]

        return result
@dataclasses.dataclass(frozen=True)
class AlohaOutputs(transforms.DataTransformFn):
    """Turn padded model actions back into 14-dim Aloha actions."""

    # When True, joints and grippers are converted from the pi runtime space
    # back to the standard Aloha space.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        # Only the first 14 dims are real Aloha actions; the rest is padding.
        first_14 = np.asarray(data["actions"][:, :14])
        encoded = _encode_actions(first_14, adapt_to_pi=self.adapt_to_pi)
        return {"actions": encoded}
def _joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def _normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def _unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def _gripper_to_angular(value):
    # Convert the gripper value of the Aloha runtime into the angular gripper
    # value used by pi0. In this file it is only called from _decode_state when
    # adapt_to_pi=True.
    #
    # Aloha transforms the gripper positions into a linear space; the code below
    # reverses this transformation to be consistent with pi0, which is
    # pretrained in angular space.
    #
    # NOTE(review): in this modified version both min/max pairs were changed to
    # (0, 1), so `value` reaches linear_to_radian unscaled and is treated as a
    # linear position with arm_length=0.036 / horn_radius=0.022. Only inputs
    # between about 0.014 and 0.058 map to intermediate angles. Smaller positive
    # inputs saturate at -pi/2 and larger ones at pi/2 (~1.571), and an input of
    # exactly 0 divides by zero. Unlike _gripper_from_angular and
    # _gripper_from_angular_inv, this is therefore NOT an identity mapping.
    # Confirm before enabling adapt_to_pi=True.
    #
    # Upstream Aloha values (PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED):
    # value = _unnormalize(value, min_val=0.01844, max_val=0.05800)
    value = _unnormalize(value, min_val=0, max_val=1)

    # This is the inverse of the angular to linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(value, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # Upstream normalized with the range measured on an actual Trossen robot:
    # return _normalize(value, min_val=0.4, max_val=1.5)
    # With (0, 1) this final step leaves the value unchanged.
    return _normalize(value, min_val=0, max_val=1)
def _gripper_from_angular(value):
    """Convert a pi0 gripper value back to the gripper value used by Aloha.

    Both ranges were set to [0, 1] in this version, so the conversion is
    numerically an identity mapping. The upstream constants are noted below.
    """
    # Upstream: unnormalize from the range measured on a Trossen robot, [0.4, 1.5].
    angular = _unnormalize(value, min_val=0, max_val=1)
    # Upstream: normalize with PUPPET_GRIPPER_JOINT_OPEN/CLOSE, [-0.6213, 1.4910].
    return _normalize(angular, min_val=0, max_val=1)
def _gripper_from_angular_inv(value):
    """Directly invert ``_gripper_from_angular``.

    Both ranges were set to [0, 1] in this version, so this is numerically an
    identity mapping as well.
    """
    # Upstream: unnormalize with [-0.6213, 1.4910].
    aloha_joint = _unnormalize(value, min_val=0, max_val=1)
    # Upstream: normalize with [0.4, 1.5].
    return _normalize(aloha_joint, min_val=0, max_val=1)
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Decode the state and images of an Aloha sample, in place.

    State layout: [left arm joints (6), left gripper (1), right arm joints (6),
    right gripper (1)]; indices 6 and 13 are the grippers. Images are converted
    from [channel, height, width] to [height, width, channel], and float images
    are scaled by 255 and cast to uint8. The same dict object is returned.
    """
    decoded_state = _decode_state(np.asarray(data["state"]), adapt_to_pi=adapt_to_pi)

    def to_hwc_uint8(img):
        arr = np.asarray(img)
        if np.issubdtype(arr.dtype, np.floating):
            # Float images are presumably in [0, 1].
            arr = (255 * arr).astype(np.uint8)
        return einops.rearrange(arr, "c h w -> h w c")

    data["images"] = {name: to_hwc_uint8(img) for name, img in data["images"].items()}
    data["state"] = decoded_state
    return data
def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
if adapt_to_pi:
# Flip the joints.
state = _joint_flip_mask() * state
# Reverse the gripper transformation that is being applied by the Aloha runtime.
state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
return state
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
if adapt_to_pi:
# Flip the joints.
actions = _joint_flip_mask() * actions
actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
if adapt_to_pi:
actions = _joint_flip_mask() * actions
actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])
return actions
我们先看下compute_norm_stats.py这个文件:
- 由 data_config = config.data.create(config.assets_dirs, config.model) 可以得到data_config变量为:
DataConfig(repo_id='R1-6-kinova-subdata', asset_id='trossen', norm_stats={'actions': NormStats(mean=array([ 1.24345315e-04, 1.25149842e-02, 1.98457297e-02, 1.32796413e-03,
-6.48936862e-03, 7.16008944e-04, 4.99603391e-01, -7.04164326e-04,
1.77294109e-02, 1.83472950e-02, 1.46187330e-03, -5.42016001e-03,
3.64383188e-04, 4.55868423e-01, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]), std=array([0.1104502 , 0.17611068, 0.15441367, 0.12295121, 0.17716125,
0.16023925, 0.4079183 , 0.12189175, 0.1991071 , 0.17835201,
0.15127407, 0.19766295, 0.19989325, 0.39015493, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. ]), q01=None, q99=None), 'state': NormStats(mean=array([ 0.09093472, 0.47080797, -0.70882577, 0.09846087, 0.4463886 ,
-0.23006314, 0.49671665, -0.06535511, 0.27723944, -0.55815661,
-0.07980683, 0.51315635, 0.2389642 , 0.46351153, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. ]), std=array([0.29812354, 0.58756095, 0.5399825 , 0.28385323, 0.47698486,
0.45689079, 0.37958947, 0.31579304, 0.55893266, 0.52739483,
0.31751233, 0.48253328, 0.57435209, 0.35923567, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. ]), q01=None, q99=None)}, repack_transforms=Group(inputs=[RepackTransform(structure={'images': {'cam_high': 'observation.images.cam_high', 'cam_left_wrist': 'observation.images.cam_left_wrist', 'cam_right_wrist': 'observation.images.cam_right_wrist'}, 'state': 'observation.state', 'actions': 'action', 'prompt': 'prompt'})], outputs=()), data_transforms=Group(inputs=(AlohaInputs(action_dim=32, adapt_to_pi=False), DeltaActions(mask=(True, True, True, True, True, True, False, True, True, True, True, True, True, False))), outputs=(AbsoluteActions(mask=(True, True, True, True, True, True, False, True, True, True, True, True, True, False)), AlohaOutputs(adapt_to_pi=False))), model_transforms=Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x79935db7c310>)], outputs=()), use_quantile_norm=False, action_sequence_keys=('action',), prompt_from_task=True, local_files_only=True, rlds_data_dir=None, action_space=None)
注意, 上述norm_stats加载的是原有模型权重对应的assets('pi0/checkpoint/openpi-assets/checkpoints/pi0_base/assets/trossen').
- 对于create_torch_dataset函数:
a. 执行 dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id) 后, 得到的dataset_meta如下:
LeRobotDatasetMetadata({
Repository ID: 'R1-6-kinova-subdata',
Total episodes: '2',
Total frames: '1550',
Features: '['observation.state', 'action', 'observation.velocity', 'observation.effort', 'observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index']',
})',
camera_keys = {list: 3} ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']
chunks_size = {int} 1000
data_path = {str} 'data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet'
episodes = {dict: 2} {0: {'episode_index': 0, 'length': 1100, 'tasks': ['The left mechanical arm grips the mineral water bottle, and the right mechanical arm grips the glass. Both arms close in on t...hile the right arm doesn’t move. After pouring, the left arm returns to upright, and both arms reset to the initial positions.']}, 1: {'episode_index': 1, 'length': 450, 'tasks': ['Move only the left arm to fetch the bottle, leaving the right arm untouched.']}}
episodes_stats = {dict: 2} {0: {'action': {'count': [1100], 'max': [-1.34064519e+00 1.30165160e+00 -1.24819708e+00 3.10865164e-01, -1.54089659e-01 3.29327255e-01 1.00000000e+00 -1.26868689e+00, -4.59930956e-01 1.97489786e+00 -6.48532694e-09 1.00751770e+00, 4.32048678e-01 1.00000000e+00], 'mean': [-1.55226982 1.09708703 -1.48813653 0.01158946 -0.41880697 -0.18473192, 0.54909092 -1.47037899 -0.91339958 1.61369002 -0.03783819 0.54097348, 0.16207524 0.51727271], 'min': [-1.87110674 0.69813168 -1.91462624 -0.13498536 -0.57924187 -1.24234772, 0. -1.70791614 -1.17341757 1.14815688 -0.13945666 0.1764496, -0.09319337 0. ], 'std': [0.19098051 0.15774167 0.17283075 0.10409866 0.10842943 0.41489309, 0.49758443 0.13296363 0.21540029 0.19072002 0.03248322 0.26368281, 0.14758804 0.49970487]}, 'episode_index': {'count': [1100], 'max': [0], 'mean': [0.], 'min': [0], 'std': [0.]}, 'frame_index': {'count': [1100], 'max': [1099], 'mean': [549.5], 'min': [0], 'std': [317.54251684]}, 'index': {'coun...
features = {dict: 12} {'action': {'dtype': 'float32', 'names': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elbow', 'right_forearm_roll', 'right_wrist_angle', 'right_wrist_rotate', 'right_gripper']], 'shape': (14,)}, 'episode_index': {'dtype': 'int64', 'names': None, 'shape': (1,)}, 'frame_index': {'dtype': 'int64', 'names': None, 'shape': (1,)}, 'index': {'dtype': 'int64', 'names': None, 'shape': (1,)}, 'observation.effort': {'dtype': 'float32', 'names': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elbow', 'right_forearm_roll', 'right_wrist_angle', 'right_wrist_rotate', 'right_gripper']], 'shape': (14,)}, 'observation.images.cam_high': {'dtype': 'image', 'names': ['channels', 'height', 'width'], 'shape': (3, 480, 640)}, 'observation.images.cam_left_wrist': {'dtype': 'ima...
fps = {int} 25
image_keys = {list: 3} ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']
info = {dict: 13} {'chunks_size': 1000, 'codebase_version': 'v2.1', 'data_path': 'data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet', 'features': {'action': {'dtype': 'float32', 'names': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elbow', 'right_forearm_roll', 'right_wrist_angle', 'right_wrist_rotate', 'right_gripper']], 'shape': (14,)}, 'episode_index': {'dtype': 'int64', 'names': None, 'shape': (1,)}, 'frame_index': {'dtype': 'int64', 'names': None, 'shape': (1,)}, 'index': {'dtype': 'int64', 'names': None, 'shape': (1,)}, 'observation.effort': {'dtype': 'float32', 'names': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elbow', 'right_forearm_roll', 'right_wrist_angle', 'right_wrist_rotate', 'right_gripper']], 'shape': (14,)}, 'observation.images.cam...
names = {dict: 12} {'action': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elbow', 'right_forearm_roll', 'right_wrist_angle', 'right_wrist_rotate', 'right_gripper']], 'episode_index': None, 'frame_index': None, 'index': None, 'observation.effort': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elbow', 'right_forearm_roll', 'right_wrist_angle', 'right_wrist_rotate', 'right_gripper']], 'observation.images.cam_high': ['channels', 'height', 'width'], 'observation.images.cam_left_wrist': ['channels', 'height', 'width'], 'observation.images.cam_right_wrist': ['channels', 'height', 'width'], 'observation.state': [['left_waist', 'left_shoulder', 'left_elbow', 'left_forearm_roll', 'left_wrist_angle', 'left_wrist_rotate', 'left_gripper', 'right_waist', 'right_shoulder', 'right_elb...
repo_id = {str} 'R1-6-kinova-subdata'
revision = {str} 'v2.1'
robot_type = {str} 'aloha'
root = {PosixPath} PosixPath('/data/datasets/custom_data/lerobot/R1-6-kinova-subdata')
shapes = {dict: 12} {'action': (14,), 'episode_index': (1,), 'frame_index': (1,), 'index': (1,), 'observation.effort': (14,), 'observation.images.cam_high': (3, 480, 640), 'observation.images.cam_left_wrist': (3, 480, 640), 'observation.images.cam_right_wrist': (3, 480, 640), 'observation.state': (14,), 'observation.velocity': (14,), 'task_index': (1,), 'timestamp': (1,)}
stats = {dict: 12} {'action': {'count': [1550], 'max': [-1.34064519e+00 1.30165160e+00 -1.06901968e+00 3.10865164e-01, -1.54089659e-01 3.29327255e-01 1.00000000e+00 -1.26868689e+00, -4.59930956e-01 1.97489786e+00 -6.48532694e-09 1.00751770e+00, 4.32048678e-01 1.00000000e+00], 'mean': [-1.58315436 1.04104879 -1.46533966 0.01377588 -0.46073431 -0.15072884, 0.54064516 -1.50865436 -0.85090248 1.70105573 -0.02685291 0.45992332, 0.11502114 0.65741934], 'min': [-1.87110674 0.61033273 -1.91462624 -0.13498536 -0.79250985 -1.24234772, 0. -1.70791614 -1.17341757 1.14815688 -0.13945666 0.1764496, -0.09319337 0. ], 'std': [0.17688692 0.18742607 0.21843506 0.0913015 0.1437362 0.36109164, 0.49834541 0.12699503 0.20609423 0.21088331 0.03230805 0.25573573, 0.14446651 0.47457507]}, 'episode_index': {'count': [1550], 'max': [1], 'mean': [0.29032258], 'min': [0], 'std': [0.4539112]}, 'frame_index': {'count': [1550], 'max': [1099], 'mean': [455.14516129], 'min': [0], 'std': [313.40187777]...
task_to_task_index = {dict: 2} {'Move only the left arm to fetch the bottle, leaving the right arm untouched.': 1, 'The left mechanical arm grips the mineral water bottle, and the right mechanical arm grips the glass. Both arms close in on t...hile the right arm doesn’t move. After pouring, the left arm returns to upright, and both arms reset to the initial positions.': 0}
tasks = {dict: 2} {0: 'The left mechanical arm grips the mineral water bottle, and the right mechanical arm grips the glass. Both arms close in on t...hile the right arm doesn’t move. After pouring, the left arm returns to upright, and both arms reset to the initial positions.', 1: 'Move only the left arm to fetch the bottle, leaving the right arm untouched.'}
total_chunks = {int} 1
total_episodes = {int} 2
total_frames = {int} 1550
total_tasks = {int} 2
video_keys = {list: 0} []
video_path = {str} 'videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4'
- 接着对数据进行切分, 如下所示:
# Load the LeRobot dataset with action chunks. For every key in
# data_config.action_sequence_keys, request action_horizon consecutive frames
# starting at the current one, at time offsets t / fps seconds
# (t = 0 .. action_horizon - 1).
dataset = lerobot_dataset.LeRobotDataset(
    data_config.repo_id,
    delta_timestamps={
        key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
    },
)
其中delta_timestamps=({'action': [0.0, 0.04, 0.08, 0.12, 0.16, 0.2, 0.24, 0.28, 0.32, 0.36, 0.4, 0.44, 0.48, 0.52, 0.56, 0.6, 0.64, 0.68, 0.72, 0.76, 0.8, 0.84, 0.88, 0.92, 0.96, 1.0, 1.04, 1.08, 1.12, 1.16, 1.2, 1.24, 1.28, 1.32, 1.36, 1.4, 1.44, 1.48, 1.52, 1.56, 1.6, 1.64, 1.68, 1.72, 1.76, 1.8, 1.84, 1.88, 1.92, 1.96]},)
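根据数据集的 FPS, delta 时间戳会被转换为帧索引, 并在这些索引处查询数据; 数据集会通过适当的填充来确保查询不会跨越 episode 边界. 上述代码可以理解为: 因为action_horizon是50, 所以动作序列是50帧; 我们的fps是25帧/秒, 所以 t / dataset_meta.fps 就是第t帧相对当前帧的时间偏移(单位为秒). 无论原始数据的采样率如何, 这种计算都保证取到的是从当前帧开始连续的action_horizon帧动作(对应的真实时间跨度随fps变化). 注意这里的fps来自数据集meta信息, 即转换脚本中传给LeRobotDataset.create的fps, 它需要与实际采集频率一致.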
我们看下PromptFromLeRobotTask
@dataclasses.dataclass(frozen=True)
class PromptFromLeRobotTask(DataTransformFn):
    """Attach the task string of the current LeRobot sample as ``prompt``."""

    # Mapping task_index -> task string (dataset.meta.tasks).
    tasks: dict[int, str]

    def __call__(self, data: DataDict) -> DataDict:
        if "task_index" not in data:
            raise ValueError('Cannot extract prompt without "task_index"')
        task_index = int(data["task_index"])
        prompt = self.tasks.get(task_index)
        if prompt is None:
            raise ValueError(f"{task_index=} not found in task mapping: {self.tasks}")
        # Return a new dict; the input sample is left untouched.
        return {**data, "prompt": prompt}
这里加入了prompt
信息给予data
, 得到单个data如下:
{'action': (shape 50, 14),
'action_is_pad': tensor([False*50),
'episode_index': tensor(0),
'frame_index': tensor(192),
'index': tensor(192),
'observation.effort': tensor([-1.8639, 1.1209, -1.5432, 0.2897, -0.2231, -0.5171, 0.4541, -1.2626,
-0.6135, 1.5777, -0.0918, 0.7735, 0.3882, 0.0087]),
'observation.images.cam_high': (shape 3* 480*14), 'observation.images.cam_left_wrist': (shape 3* 480*14),
'observation.images.cam_right_wrist': (shape 3* 480*14)
,'observation.state': tensor([-1.8639, 1.1209, -1.5432, 0.2897, -0.2231, -0.5171, 0.4541, -1.2626,
-0.6135, 1.5777, -0.0918, 0.7735, 0.3882, 0.0087]),
'observation.velocity': tensor([-1.8639, 1.1209, -1.5432, 0.2897, -0.2231, -0.5171, 0.4541, -1.2626, -0.6135, 1.5777, -0.0918, 0.7735, 0.3882, 0.0087]),
'prompt': 'The left mechanical arm grips the mineral water bottle, and the right mechanical arm grips the glass. Both arms close in on the center together until reaching a fitting position. The left gripper rotates to pour water, while the right arm doesn’t move. After pouring, the left arm returns to upright, and both arms reset to the initial positions.',
'task': 'The left mechanical arm grips the mineral water bottle, and the right mechanical arm grips the glass. Both arms close in on the center together until reaching a fitting position. The left gripper rotates to pour water, while the right arm doesn’t move. After pouring, the left arm returns to upright, and both arms reset to the initial positions.',
'task_index': tensor(0),
'timestamp': tensor(7.6800)}
之后会对数据进行 transform 操作, 如下所示:
# Wrap the dataset so every sample goes through the repack transforms and the data
# transforms before the normalization statistics are computed.
dataset = _data_loader.TransformedDataset(
    dataset,
    [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
        RemoveStrings(),
    ],
)
- 这里说明下
repack_transforms
的用处, 我们看下面的结果就知道了.
@dataclasses.dataclass(frozen=True)
class RepackTransform(DataTransformFn):
"""Repacks an input dictionary into a new dictionary.
Repacking is defined using a dictionary where the keys are the new keys and the values
are the flattened paths to the old keys. We use '/' as the separator during flattening.
Example:
{
"images": {
"cam_high": "observation.images.top",
"cam_low": "observation.images.bottom",
},
"state": "observation.state",
"actions": "action",
}
"""
structure: at.PyTree[str]
def __call__(self, data: DataDict) -> DataDict:
flat_item = flatten_dict(data)
return jax.tree.map(lambda k: flat_item[k], self.structure)
# self.structure {'actions': 'action', 'images': {'cam_high': 'observation.images.cam_high', 'cam_left_wrist': 'observation.images.cam_left_wrist', 'cam_right_wrist': 'observation.images.cam_right_wrist'}, 'prompt': 'prompt', 'state': 'observation.state'}
a) repack之前:
b) repack之后:
同时这里task以及prompt是字符串类型无需计算norm, 所以这里用
RemoveStrings
操作进行删除, 之后获取data_loader
并随着num_batches
进行返回.
# Excerpt from the body of the norm-stats data-loader factory (the enclosing `def` is not
# shown, hence the bare `return`).
# If max_frames caps the dataset, draw max_frames // batch_size batches in shuffled order;
# otherwise iterate len(dataset) // batch_size batches in order (a trailing partial batch
# is not counted).
if max_frames is not None and max_frames < len(dataset):
    num_batches = max_frames // batch_size
    shuffle = True
else:
    num_batches = len(dataset) // batch_size
    shuffle = False
data_loader = _data_loader.TorchDataLoader(
    dataset,
    local_batch_size=batch_size,
    num_workers=16,  # hard-coded worker count
    shuffle=shuffle,
    num_batches=num_batches,
)
return data_loader, num_batches
最后我们看下下面的代码
# Accumulate running mean/std statistics for state and actions over the whole dataset.
keys = ["state", "actions"]
stats = {key: normalize.RunningStats() for key in keys}
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
    for key in keys:
        # Use the full batch. Indexing `batch[key][0]` would feed only the first sample of
        # every batch into the statistics. All leading dimensions (batch, and for actions
        # also the horizon) are flattened so each row is one feature vector.
        values = np.asarray(batch[key])
        stats[key].update(values.reshape(-1, values.shape[-1]))
# Use a distinct loop name so the `stats` dict is not shadowed inside the comprehension.
norm_stats = {key: running.get_statistics() for key, running in stats.items()}
这里的stats
如下:
{'actions': <openpi.shared.normalize.RunningStats object at 0x78f8213465d0>,
'state': <openpi.shared.normalize.RunningStats object at 0x78f821345fd0>}
最终得到的norm信息如下:
二、 模型训练
训练脚本如下所示:
import dataclasses
import functools
import logging
import os
import platform
from typing import Any

# lerobot reads HF_LEROBOT_HOME into a module-level constant when it is first imported,
# and openpi.training.data_loader (imported below) imports lerobot. The variable must
# therefore be set *before* those imports; assigning it afterwards does not change where
# datasets are looked up. setdefault keeps a value the user already exported, e.g.
# `HF_LEROBOT_HOME=/path python3 train.py ...`.
os.environ.setdefault("HF_LEROBOT_HOME", "/data/datasets/custom_data/lerobot")

import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
def init_logging():
    """Install a compact log format on the root logger and set its level to INFO.

    Level names are abbreviated to one letter (DEBUG->D, INFO->I, WARNING->W,
    ERROR->E, CRITICAL->C). If the root logger has no handler yet, a StreamHandler
    is added first so there is something to attach the formatter to.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # Abbreviate only for this formatter's output and restore the record afterwards,
            # so other handlers formatting the same record still see the full level name.
            original_levelname = record.levelname
            record.levelname = level_mapping.get(original_levelname, original_levelname)
            try:
                return super().format(record)
            finally:
                record.levelname = original_levelname

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Without a configured handler, indexing handlers[0] would raise IndexError.
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Start (or resume) the Weights & Biases run for this training job.

    - ``enabled=False``: wandb is initialised in "disabled" mode and nothing else happens.
    - ``resuming=True``: the run id stored in ``<checkpoint_dir>/wandb_id.txt`` is
      re-attached with ``resume="must"``.
    - otherwise a new run named ``config.exp_name`` is created and its id is written to
      ``wandb_id.txt`` so a later run can resume it.

    Raises:
        FileNotFoundError: if ``config.checkpoint_dir`` does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    id_file = ckpt_dir / "wandb_id.txt"
    if not resuming:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        id_file.write_text(wandb.run.id)
    else:
        previous_run_id = id_file.read_text().strip()
        wandb.init(id=previous_run_id, resume="must", project=config.project_name)

    if log_code:
        # Upload the source tree two directory levels above this file.
        wandb.run.log_code(epath.Path(__file__).parent.parent)
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Load weights with ``loader`` and return only the entries it actually filled in.

    ``params_shape`` is the abstract parameter tree. The loaded tree must match it in
    structure, shapes and dtypes. Leaves that are still ``jax.ShapeDtypeStruct``
    placeholders (not provided by the loader) are dropped from the result.
    """
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    concrete_only = {}
    for path, leaf in traverse_util.flatten_dict(loaded_params).items():
        if isinstance(leaf, jax.ShapeDtypeStruct):
            continue  # placeholder: this parameter was not loaded
        concrete_only[path] = leaf
    return traverse_util.unflatten_dict(concrete_only)
@at.typecheck
def init_train_state(
    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Build the initial TrainState and its sharding over ``mesh``.

    When ``resume`` is True, only the abstract state (from ``jax.eval_shape``) and its
    sharding are returned; the caller is expected to restore concrete values from a
    checkpoint. Otherwise the weights from ``config.weight_loader`` are loaded and merged
    into a freshly initialised model inside ``jax.jit``.

    Returns:
        ``(train_state, state_sharding)``.
    """
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        # initialize the model (and its parameters).
        model = config.model.create(model_rng)

        # Merge the partial params into the model.
        if partial_params is not None:
            graphdef, state = nnx.split(model)
            # This will produce an error if the partial params are not a subset of the state.
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params to bfloat16.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            # Optimizer state exists only for the trainable subset of the parameters.
            opt_state=tx.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    # Trace `init` abstractly to obtain shapes/dtypes, then derive an FSDP sharding from them.
    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    # Empty PartitionSpec: inputs to the jitted init are replicated on every device.
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize the train state and mix in the partial params.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),  # donate the partial params buffer.
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding
@at.typecheck
def train_step(
    config: _config.TrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one optimisation step on ``batch``.

    Only parameters selected by ``config.trainable_filter`` receive gradients and optimizer
    updates. When ``state.ema_decay`` is set, the EMA parameters are updated as
    ``ema = decay * ema + (1 - decay) * params``.

    Returns:
        ``(new_state, info)``. ``info`` holds "loss", "grad_norm" (global norm of the
        gradients) and "param_norm" (global norm of multi-dimensional kernel parameters).
    """
    # Rebuild a stateful nnx model from the graph definition and the current parameters.
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
    ):
        chunked_loss = model.compute_loss(rng, observation, actions, train=True)
        # Average everything compute_loss returns into a single scalar.
        return jnp.mean(chunked_loss)

    # Derive a different key for every step from the same base key.
    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Filter out frozen params.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )

    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info
def main(config: _config.TrainConfig):
    """Train from ``config``: set up devices, checkpoints, wandb, data and model state, then loop."""
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    # The global batch must split evenly across devices.
    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    # Store compiled XLA programs on disk so later runs can reuse them.
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    # Batches are split along their first dimension; small outputs (metrics) are replicated.
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    data_loader = _data_loader.create_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    # Log images from first batch to sanity check.
    # batch[0] is the observation; for each of up to 5 samples, all camera views are
    # concatenated side by side (axis=1) into one image.
    images_to_log = [
        wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
        for i in range(min(5, len(next(iter(batch[0].images.values())))))
    ]
    wandb.log({"camera_views": images_to_log}, step=0)

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        # init_train_state returned only the abstract state; load the checkpointed values.
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    # Compile the step once with `config` bound; argument 1 (the train state) is donated so
    # its buffers can be reused for the returned state.
    ptrain_step = jax.jit(
        functools.partial(train_step, config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)
        # Every log_interval steps, average the metrics collected since the last log.
        if step % config.log_interval == 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        # Prefetch the batch for the next step.
        batch = next(data_iter)

        # Save every save_interval steps (except at the step training started from) and at the last step.
        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()
if __name__ == "__main__":
    # Parse a TrainConfig from the command line, then start training with it.
    cli_config = _config.cli()
    main(cli_config)
1. 随机数生成
jax.random.PRNGKey 是 JAX 库中用于创建伪随机数生成器密钥(PRNG key)的函数。
PRNG key 是 JAX 中用于生成伪随机数序列的关键要素. JAX 的随机数是无状态的: 同一个 key 每次都会生成相同的随机数, 因此一个 key 用于采样之后就不应再复用, 而是通过分割(splitting)生成新的 PRNG key, 以确保生成的随机数序列互不相关. 下面是一个例子
# Create a PRNG key from a seed.
rng_key = jax.random.PRNGKey(42)
# JAX random functions are pure: a key that has been consumed by a sampler must not be
# reused (or split afterwards) - split first and sample with the fresh subkey instead.
rng_key, subkey = jax.random.split(rng_key)
# Draw random numbers in [0, 1).
random_numbers = jax.random.uniform(subkey, shape=(3, 3))
print(random_numbers)
# Split the (still unused) key into two new independent keys.
new_rng_key1, new_rng_key2 = jax.random.split(rng_key, num=2)
2. FSDP模型分片操作
mesh = sharding.make_mesh(config.fsdp_devices)  # config.fsdp_devices defaults to 1
# Split the leading (batch) dimension of input arrays over the mesh axes in DATA_AXIS.
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
# Empty PartitionSpec: no dimension is split, the array is replicated on every device.
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
(1) make_mesh 函数
BATCH_AXIS = "batch"
FSDP_AXIS = "fsdp"
# In FSDP, we shard the data across both the batch and FSDP axes.
DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)
def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
#这部分代码确保可用的设备总数能被 FSDP 设备数整除,否则抛出错误。例如,若有 8 个设备,num_fsdp_devices 可以是 1、2、4、8,但不能是 3。
if jax.device_count() % num_fsdp_devices != 0:
raise ValueError(
f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}."
)
# 网格形状是一个二元组,第一个元素表示批次维度的设备数,第二个元素表示 FSDP 维度的设备数。例如,若有 8 个设备且 num_fsdp_devices=2,则网格形状为 (4, 2),表示 4 个批次设备组,每组 2 个设备用于 FSDP 分片
mesh_shape = (jax.device_count() // num_fsdp_devices, num_fsdp_devices)
return jax.make_mesh(mesh_shape, (BATCH_AXIS, FSDP_AXIS)) # 使用计算出的形状和指定的轴名称(BATCH_AXIS 和 FSDP_AXIS)创建网格。这两个轴名称是字符串常量,用于后续的张量分片规则。
make_mesh 函数的目的是根据指定的 FSDP(Fully Sharded Data Parallel)设备数量,创建一个二维的设备网格,用于并行训练。这个网格将设备划分为两个维度:批次维度(BATCH_AXIS)和 FSDP 分片维度(FSDP_AXIS)
(2) jax.sharding.NamedSharding
# Split the first dimension of input arrays over every mesh axis listed in DATA_AXIS.
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
- 作用: 创建一个数据并行的分片策略,将张量沿 sharding.DATA_AXIS 维度分割到不同设备上。
- 详解:
NamedSharding:JAX 中基于命名轴的分片方式,通过 mesh 和 PartitionSpec 定义张量如何分布。
mesh:是你之前创建的二维设备网格 (BATCH_AXIS, FSDP_AXIS)。
PartitionSpec(sharding.DATA_AXIS):由于 DATA_AXIS = (BATCH_AXIS, FSDP_AXIS), 张量的第一个维度会被同时分割到 mesh 的 BATCH_AXIS 和 FSDP_AXIS 两个维度上, 也就是分到所有设备上; 其余维度不分割。 - 示例
若 mesh 形状为 (4, 2)(4 个批次组, 每组 2 个 FSDP 设备):
一个形状为 [8, 100] 的张量会沿第一维被分割成 4×2=8 块(每块形状 [1, 100]), 8 个设备各持有一块, 不存在复制。(只有当 PartitionSpec 只写 BATCH_AXIS 时, 才会分成 4 块 [2, 100], 并在每组内的 2 个 FSDP 设备之间复制。) 本文的配置中 mesh 为 (8, 1), 两种写法结果相同。
(3) NamedSharding
# Empty PartitionSpec: no dimension is split, so arrays are copied whole to every device of the mesh.
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
- 作用: 创建一个全复制的分片策略,将整个张量复制到所有设备上。
- 详解: PartitionSpec():空的分片规范,表示不分割任何维度。无论 mesh 形状如何,张量都会被完整复制到每个设备上。
- 示例: 对于同样的 (4, 2) 网格,一个形状为 [8, 100] 的张量会被完整复制到所有 8 个设备上。
3. 初始化checkpoint
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
    config.checkpoint_dir,  # e.g. openpi/scripts/checkpoints/{config_name}/{exp_name}
    keep_period=config.keep_period,  # 5000 here; forwarded to Orbax's keep_period option
    overwrite=config.overwrite,  # True: delete and recreate the checkpoint directory if it already exists
    resume=config.resume,
)
# The corresponding implementation follows.
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Prepare ``checkpoint_dir`` and create the Orbax CheckpointManager for it.

    If the directory already exists:
      * ``overwrite=True`` wipes and recreates it (checked first, so it wins over ``resume``);
      * otherwise ``resume=True`` marks the run as resuming;
      * otherwise ``FileExistsError`` is raised.

    Returns:
        ``(manager, resuming)``. ``resuming`` is reset to False when the directory holds no
        checkpoint other than step 0, since there is nothing to restore.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    resuming = False

    if checkpoint_dir.exists():
        if overwrite:
            checkpoint_dir.rmtree()
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
        elif not resume:
            raise FileExistsError(
                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
                "to indicate how to handle it."
            )
        else:
            resuming = True

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    handlers = {
        "assets": CallbackHandler(),
        "train_state": ocp.PyTreeCheckpointHandler(),
        "params": ocp.PyTreeCheckpointHandler(),
    }
    manager_options = ocp.CheckpointManagerOptions(
        max_to_keep=1,
        keep_period=keep_period,
        create=False,  # the directory was created explicitly above
        async_options=ocp.AsyncOptions(timeout_secs=7200),
    )
    mngr = ocp.CheckpointManager(checkpoint_dir, item_handlers=handlers, options=manager_options)

    # A previous run may have created the directory without ever saving a real checkpoint;
    # trying to restore from it would fail, so start from scratch instead.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False

    return mngr, resuming
Orbax(JAX/TensorFlow 的检查点库)创建了一个 CheckpointManager
(1) CheckpointManager
Orbax 的核心类,负责:协调不同类型对象(如模型参数、优化器状态)的保存 / 加载。实现检查点的版本控制(如保留最近 n 个检查点)。支持异步操作以避免阻塞训练循环。
(2) item_handlers
指定不同对象的保存 / 加载方式:
"assets": 使用 CallbackHandler(),适用于自定义对象(如字典、列表),需通过回调函数定义读写逻辑。
如下所示:
class CallbackHandler(ocp.AsyncCheckpointHandler):
    """Orbax handler that runs ``args.callback(directory)`` on a background thread.

    Save-only: ``restore`` raises NotImplementedError. The callback runs only on
    process 0, and the single-worker thread pool runs submitted saves one at a time.
    """

    def __init__(self):
        self._executor = futures.ThreadPoolExecutor(max_workers=1)

    def close(self):
        self._executor.shutdown()

    def save(self, directory: epath.Path, args: "CallbackSave"):
        # Only the primary process invokes the callback.
        if jax.process_index() != 0:
            return
        args.callback(directory)

    async def async_save(self, directory: epath.Path, args: "CallbackSave") -> list[futures.Future]:
        pending = self._executor.submit(self.save, directory, args)
        return [pending]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")
"train_state" 和 "params": 使用 PyTreeCheckpointHandler(),专门处理 JAX 的 PyTree 结构(如嵌套字典、数组)。
(3) CheckpointManagerOptions
配置检查点管理策略:
max_to_keep=1: 只保留最近的 1 个检查点,节省磁盘空间。
keep_period: 步数能被 keep_period 整除的检查点会被永久保留, 不会因为 max_to_keep 而被删除(这里为 5000, 即第 5000、10000、... 步保存的检查点都会保留)。
create=False: 不自动创建检查点目录(需手动确保目录存在)。
async_options: 异步保存的超时设置(7200 秒 = 2 小时),避免长时间阻塞。
4. init_wandb
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
对应函数如下所示:
def init_wandb(
    config: _config.TrainConfig,
    *,
    resuming: bool,
    log_code: bool = False,
    enabled: bool = True,
    mode: str = "offline",
):
    """Start or resume the wandb run, logging offline by default.

    Args:
        config: Training config; provides checkpoint_dir, exp_name and project_name.
        resuming: Re-attach to the run id stored in ``<checkpoint_dir>/wandb_id.txt``.
        log_code: Also upload the source tree two levels above this file.
        enabled: If False, wandb is initialised in "disabled" mode and nothing else happens.
        mode: Passed to ``wandb.init``. Defaults to "offline" (runs are stored locally,
            e.g. on machines without network access, and can be uploaded later with
            ``wandb sync``); use "online" to log directly.

    Raises:
        FileNotFoundError: if ``config.checkpoint_dir`` does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name, mode=mode)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
            mode=mode,
        )
        # Persist the run id so a resumed run can re-attach to the same wandb run.
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)  # upload the source code
⭐️5. 创建数据dataloader
data_loader = _data_loader.create_data_loader(
    config,
    sharding=data_sharding,  # NamedSharding(mesh=Mesh('batch': 8, 'fsdp': 1), spec=PartitionSpec(('batch', 'fsdp'),), memory_kind=device)
    shuffle=True,
)
# Fetch the first batch eagerly so its structure can be logged before training starts.
data_iter = iter(data_loader)
batch = next(data_iter)
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
我们先看下对应的create_data_loader
函数
def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Builds the DataConfig from ``config.data`` and dispatches on it: an RLDS loader when
    ``rlds_data_dir`` is set, otherwise a PyTorch/LeRobot-backed loader, which also
    receives the model config, worker count and seed.
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    common_kwargs = dict(
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        skip_norm_stats=skip_norm_stats,
    )

    if data_config.rlds_data_dir is None:
        return create_torch_data_loader(
            data_config,
            model_config=config.model,
            num_workers=config.num_workers,
            seed=config.seed,
            **common_kwargs,
        )
    return create_rlds_data_loader(data_config, **common_kwargs)
该函数先进入create
函数生成data_config
, 我们看下create
函数
# config.py -> class LeRobotAlohaDataConfig(DataConfigFactory)
@dataclasses.dataclass(frozen=True)
class LeRobotAlohaDataConfig(DataConfigFactory):
    """Data config factory for Aloha-style LeRobot datasets.

    ``create`` combines the repack transforms, the Aloha data transforms (optionally with
    delta-action conversion) and the model transforms on top of the base config.
    """

    # If true, will convert joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions: bool = True
    # If provided, will be injected into the input data if the "prompt" key is not present.
    default_prompt: str | None = None
    # If true, this will convert the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model. People who
    # use standard Aloha data should set this to true.
    adapt_to_pi: bool = True

    # Repack transforms. The default maps only "observation.images.top" to cam_high;
    # datasets with other camera keys (e.g. left/right wrist cameras) need their own structure.
    repack_transforms: tyro.conf.Suppress[_transforms.Group] = dataclasses.field(
        default=_transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "images": {"cam_high": "observation.images.top"},
                        "state": "observation.state",
                        "actions": "action",
                    }
                )
            ]
        )
    )
    # Action keys that will be used to read the action sequence from the dataset.
    action_sequence_keys: Sequence[str] = ("action",)

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        data_transforms = _transforms.Group(
            inputs=[aloha_policy.AlohaInputs(action_dim=model_config.action_dim, adapt_to_pi=self.adapt_to_pi)],
            outputs=[aloha_policy.AlohaOutputs(adapt_to_pi=self.adapt_to_pi)],
        )
        if self.use_delta_joint_actions:
            # 6 arm joints (delta), 1 gripper (absolute), 6 arm joints (delta), 1 gripper (absolute):
            # (True, True, True, True, True, True, False, True, True, True, True, True, True, False)
            delta_action_mask = _transforms.make_bool_mask(6, -1, 6, -1)
            # As the printed DataConfig shows, DeltaActions runs after AlohaInputs on input and
            # AbsoluteActions runs before AlohaOutputs on output.
            data_transforms = data_transforms.push(
                inputs=[_transforms.DeltaActions(delta_action_mask)],
                outputs=[_transforms.AbsoluteActions(delta_action_mask)],
            )

        model_transforms = ModelTransformFactory(default_prompt=self.default_prompt)(model_config)

        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=self.repack_transforms,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
            action_sequence_keys=self.action_sequence_keys,
        )
我们可以看下data_transforms
, 如下图所示
我们在看下
ModelTransformFactory
@dataclasses.dataclass(frozen=True)
class ModelTransformFactory(GroupFactory):
"""Creates model transforms for standard pi0 models."""
# If provided, will determine the default prompt that be used by the model.
default_prompt: str | None = None
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
match model_config.model_type:
case _model.ModelType.PI0:
return _transforms.Group(
inputs=[
_transforms.InjectDefaultPrompt(self.default_prompt),
_transforms.ResizeImages(224, 224),
_transforms.TokenizePrompt(
_tokenizer.PaligemmaTokenizer(model_config.max_token_len),
),
],
)
case _model.ModelType.PI0_FAST:
return _transforms.Group(
inputs=[
_transforms.InjectDefaultPrompt(self.default_prompt),
_transforms.ResizeImages(224, 224),
_transforms.TokenizeFASTInputs(
_tokenizer.FASTTokenizer(model_config.max_token_len),
),
],
outputs=[
_transforms.ExtractFASTActions(
_tokenizer.FASTTokenizer(model_config.max_token_len),
action_horizon=model_config.action_horizon,
action_dim=model_config.action_dim,
)
],
)
我们使用的是PI0所以我们可以看到这里设置针对图像需要缩放大小是224*224
, 并设置PaligemmaTokenizer的max_token_len
. 最终得到的model_transforms
如下所示:
Group(inputs=[InjectDefaultPrompt(prompt=None), ResizeImages(height=224, width=224), TokenizePrompt(tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x75d7d978cc90>)], outputs=())
并经过dataclass.replace
操作, 对DataConfig进行修改得到如下信息
之后将上述的config文件传入到create_torch_data_loader
中. 我们看下传入的变量:
- data_config:
DataConfig(
repo_id='R1-6-kinova-subdata',
asset_id='trossen',
norm_stats={
'actions': NormStats(mean=array([shape=32]), std=array([shape=32]), q01=None, q99=None),
'state': NormStats(mean=array([shape=32]), std=array([shape=32]), q01=None, q99=None)
},
repack_transforms=Group(
inputs=[RepackTransform(
structure={
'images': {
'cam_high':
'observation.images.cam_high',
'cam_left_wrist': 'observation.images.cam_left_wrist',
'cam_right_wrist':'observation.images.cam_right_wrist'},
'state': 'observation.state',
'actions': 'action',
'prompt': 'prompt'
}
)
],
outputs=()
),
data_transforms=Group(
inputs=(
AlohaInputs(action_dim=32, adapt_to_pi=False),
DeltaActions(
mask=(True, True, True, True, True, True, False, True, True, True, True, True, True, False)
)
),
outputs=(
AbsoluteActions(
mask=(True, True, True, True, True, True, False, True, True, True, True, True, True, False)
),
AlohaOutputs(adapt_to_pi=False)
)
),
model_transforms=Group(
inputs=[
InjectDefaultPrompt(prompt=None),
ResizeImages(height=224, width=224),
TokenizePrompt(
tokenizer=<openpi.models.tokenizer.PaligemmaTokenizer object at 0x750d71b2c710>
)
], outputs=()),
use_quantile_norm=False,
action_sequence_keys=('action',),
prompt_from_task=True,
local_files_only=True,
rlds_data_dir=None,
action_space=None
)
- model_config:
Pi0Config(
action_dim=32,
action_horizon=50,
max_token_len=90,
dtype='bfloat16',
paligemma_variant='gemma_2b',
action_expert_variant='gemma_300m'
)
- 其他参数
batch_size: 64,
sharding: NamedSharding(
mesh=Mesh('batch': 8, 'fsdp': 1),
spec=PartitionSpec(
('batch', 'fsdp'),
),
memory_kind=device
),
shuffle: True,
num_batches: None,
num_workers: 16,
seed:32,
skip_norm_stats=False
下面继续看create_torch_data_loader
函数
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: The model configuration, forwarded to ``create_torch_dataset``.
        action_horizon: The action horizon.
        batch_size: The global batch size; each process loads
            ``batch_size // jax.process_count()`` samples per batch.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.

    Returns:
        A ``DataLoaderImpl`` pairing ``data_config`` with the underlying torch loader.
    """
    raw_dataset = create_torch_dataset(data_config, action_horizon, model_config)
    transformed = transform_dataset(raw_dataset, data_config, skip_norm_stats=skip_norm_stats)

    per_process_batch = batch_size // jax.process_count()
    torch_loader = TorchDataLoader(
        transformed,
        local_batch_size=per_process_batch,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
    )
    return DataLoaderImpl(data_config, torch_loader)
首先看下create_torch_dataset
函数
def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training.

    ``repo_id == "fake"`` yields a 1024-sample FakeDataset. Otherwise a LeRobotDataset is
    opened that returns, for every action key, the next ``action_horizon`` frames. When
    ``prompt_from_task`` is set, each sample also gets its task description as "prompt".

    Raises:
        ValueError: if ``data_config.repo_id`` is None.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    # Offsets in seconds of the next `action_horizon` frames: 0, 1/fps, 2/fps, ...
    horizon_offsets = {
        seq_key: [frame / meta.fps for frame in range(action_horizon)]
        for seq_key in data_config.action_sequence_keys
    }
    dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, delta_timestamps=horizon_offsets)

    if not data_config.prompt_from_task:
        return dataset
    return TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(meta.tasks)])
a. 通过 repo_id
得到dataset_meta
, 得到的dataset_meta如下:
同时, 下面代码会基于fps以及horizon进行数据切分.
# Time offsets (seconds) of the next `action_horizon` frames, one list per action key.
offsets_per_key = {}
for seq_key in data_config.action_sequence_keys:
    offsets_per_key[seq_key] = [frame / dataset_meta.fps for frame in range(action_horizon)]
dataset = lerobot_dataset.LeRobotDataset(data_config.repo_id, delta_timestamps=offsets_per_key)
并得到对应的dataset
,
对应的transformers是
其对应的类是 TransformedDataset, 它会先对数据进行 transform 变换, 再返回:
我们可以看下dataset[0]对应的值,
之后数据会进入transform_dataset
函数中, 该函数加入更多的数据transform
操作
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Transform the dataset by applying the data transforms.

    Order: repack -> data transforms -> Normalize -> model transforms. Normalization uses
    ``data_config.norm_stats`` unless the dataset is "fake" or ``skip_norm_stats`` is set,
    in which case an empty stats dict is passed.

    Raises:
        ValueError: if normalization is required but ``data_config.norm_stats`` is None.
    """
    needs_stats = data_config.repo_id != "fake" and not skip_norm_stats
    if not needs_stats:
        norm_stats = {}
    elif data_config.norm_stats is None:
        raise ValueError(
            "Normalization stats not found. "
            "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
        )
    else:
        norm_stats = data_config.norm_stats

    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return TransformedDataset(dataset, pipeline)
@dataclasses.dataclass(frozen=True)
class CompositeTransform(DataTransformFn):
    """Chain transforms: each one receives the output of the previous one, in order."""

    transforms: Sequence[DataTransformFn]

    def __call__(self, data: DataDict) -> DataDict:
        result = data
        for step_fn in self.transforms:
            result = step_fn(result)
        return result
我们看下之后的操作有哪些?可以看下图所示, 后面会一步步进行分解
-
RepackTransform
前面已经讲过了, 就是把以 `.` 命名的扁平键(如 `observation.images.cam_high`)重新打包成嵌套的结构化字典(如 `images` 下的 `cam_high`) -
AlohaInputs
: 我们看下面的代码:
@dataclasses.dataclass(frozen=True)
class AlohaInputs(transforms.DataTransformFn):
    """Convert an Aloha-style sample into the model's input dictionary.

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
    - state: [14]
    - actions: [action_horizon, 14]

    Output keys: "image" (base_0_rgb, left_wrist_0_rgb, right_wrist_0_rgb), "image_mask",
    "state" zero-padded to ``action_dim``, plus "actions" (padded to ``action_dim``) and
    "prompt" when they are present in the input.
    """

    # Model action dimension; state and actions are padded up to this size.
    action_dim: int

    # Convert joint/gripper values from the standard Aloha space to the pi internal runtime space.
    adapt_to_pi: bool = True

    # Allowed camera names (this setup has no "cam_low"). Any other name raises ValueError;
    # a missing wrist camera is replaced by a black image whose image_mask is False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_left_wrist", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)

        # State goes from 14 values to the model action dim.
        state = transforms.pad_to_dim(data["state"], self.action_dim)

        in_images = data["images"]
        unexpected = set(in_images) - set(self.EXPECTED_CAMERAS)
        if unexpected:
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

        # The high camera is assumed to always be present.
        base_image = in_images["cam_high"]
        images = {"base_0_rgb": base_image}
        image_masks = {"base_0_rgb": np.True_}

        wrist_views = (("left_wrist_0_rgb", "cam_left_wrist"), ("right_wrist_0_rgb", "cam_right_wrist"))
        for model_name, camera_name in wrist_views:
            present = camera_name in in_images
            images[model_name] = in_images[camera_name] if present else np.zeros_like(base_image)
            image_masks[model_name] = np.True_ if present else np.False_

        inputs = {"image": images, "image_mask": image_masks, "state": state}

        # Actions exist only during training.
        if "actions" in data:
            raw_actions = np.asarray(data["actions"])  # (action_horizon, 14)
            encoded = _encode_actions_inv(raw_actions, adapt_to_pi=self.adapt_to_pi)
            inputs["actions"] = transforms.pad_to_dim(encoded, self.action_dim)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs
我们需要先看下_decode_aloha函数
, 如下所示:
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    """Decode a raw Aloha sample: run the state through ``_decode_state`` and make images uint8 HWC.

    ``data`` is modified in place (its ``"images"`` and ``"state"`` entries are replaced) and
    returned.
    """
    # 14-dim state with dim sizes [6, 1, 6, 1]. The original comment named the parts
    # [left_arm_joint_angles, right_arm_joint_angles, left_arm_gripper, right_arm_gripper]. The sizes
    # and the DeltaActions gripper mask at indices 6 and 13 suggest a different order:
    # [left_arm_joints, left_gripper, right_arm_joints, right_gripper]. TODO confirm.
    raw_state = np.asarray(data["state"])
    # NOTE(review): this post's config passes adapt_to_pi=False, and the author states the state is
    # then left unchanged. _decode_state itself is not visible here, so confirm.
    decoded_state = _decode_state(raw_state, adapt_to_pi=adapt_to_pi)

    def to_hwc_uint8(image):
        pixels = np.asarray(image)
        # Float images are presumably in [0, 1]; scale them to 0..255 and cast to uint8.
        if np.issubdtype(pixels.dtype, np.floating):
            pixels = (255 * pixels).astype(np.uint8)
        # Channels-first [c, h, w] -> channels-last [h, w, c].
        return einops.rearrange(pixels, "c h w -> h w c")

    data["images"] = {name: to_hwc_uint8(image) for name, image in data["images"].items()}
    data["state"] = decoded_state
    return data
`self.action_dim` 这里默认是 32, 即把我们 14 维的特征 pad 到 32 维, 多出来的维度用 0 补充.
也就是说, 这里分别把 action 和 state 的 14 个维度都补齐成了 32 个维度. 同时, 存在的相机图片的 mask 设为 True; 缺失的腕部相机用全黑图片代替, mask 设为 False. 最后把对应的 prompt 取出来.
DeltaActions
@dataclasses.dataclass(frozen=True)
class DeltaActions(DataTransformFn):
    """Repacks absolute actions into delta action space.

    For every action dimension ``i < len(mask)`` with ``mask[i]`` set, the current state value
    ``state[..., i]`` is subtracted from that dimension at every step of the action chunk.
    Masked-out dimensions, and any dimensions at or beyond ``len(mask)``, keep their absolute values.

    The incoming ``data["actions"]`` array is never modified in place. The transform works on a
    copy and stores it back under ``data["actions"]``.
    """

    # Boolean mask for the action dimensions to be repacked into delta action space. Length
    # can be smaller than the actual number of dimensions. If None, this transform is a no-op.
    # See `make_bool_mask` for more details. In this post's two-arm config the mask is
    # [True] * 6 + [False] + [True] * 6 + [False], i.e. the 7th and 14th dims (grippers) stay absolute.
    mask: Sequence[bool] | None

    def __call__(self, data: DataDict) -> DataDict:
        if "actions" not in data or self.mask is None:
            return data

        state = np.asarray(data["state"])
        # Work on a copy. Subtracting in place would silently modify the caller's array; for
        # example, if a dataset hands out cached arrays they would be delta-encoded again on the
        # next read. It also fails outright for immutable inputs such as JAX arrays.
        actions = np.array(data["actions"], copy=True)
        mask = np.asarray(self.mask)
        dims = mask.shape[-1]
        # np.where keeps state values only for masked-in dims (0 elsewhere). expand_dims inserts the
        # action-horizon axis so one offset broadcasts over every step of the chunk.
        actions[..., :dims] -= np.expand_dims(np.where(mask, state[..., :dims], 0), axis=-2)
        data["actions"] = actions
        return data
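这个类实现了一种动作空间转换: 把绝对动作值转换为相对于当前状态的增量值 (delta). 在机器人策略学习中, 这种转换常用于让模型学习状态的变化量而非绝对数值, 有助于提高训练稳定性 (openpi 这里是基于采集的示教数据训练, 并非强化学习). 其实就是 action - state, 即相对当前状态, action 的增量是多少. mask 为 False 的维度 (这里是左右两个夹爪, 即第 7 和第 14 维) 不做相减, 保留绝对值.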
-
Normalize
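对数据中在 norm_stats 里有对应统计量的字段进行归一化. 默认使用 z-score: (x - mean) / (std + 1e-6); 当 use_quantiles=True 时, 则利用 q01/q99 把数据映射到约 [-1, 1].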
@dataclasses.dataclass(frozen=True)
class Normalize(DataTransformFn):
    """Normalize every leaf of a sample that has a matching entry in ``norm_stats``.

    Uses z-score normalization by default, or q01/q99 quantile normalization (mapping
    [q01, q99] onto roughly [-1, 1]) when ``use_quantiles`` is set. A ``None`` ``norm_stats``
    makes the transform a no-op.
    """

    norm_stats: at.PyTree[NormStats] | None
    # When True, normalize with the q01/q99 quantiles instead of mean/std (z-score).
    use_quantiles: bool = False
    # When True, raise if any key in norm_stats is missing from the data.
    strict: bool = False

    def __post_init__(self):
        # Quantile stats are only validated when they will actually be used.
        if self.norm_stats is None or not self.use_quantiles:
            return
        _assert_quantile_stats(self.norm_stats)

    def __call__(self, data: DataDict) -> DataDict:
        if self.norm_stats is None:
            return data
        # In this post's config use_quantiles is False, so the z-score path is taken.
        normalize_fn = self._normalize_quantile if self.use_quantiles else self._normalize
        return apply_tree(data, self.norm_stats, normalize_fn, strict=self.strict)

    def _normalize(self, x, stats: NormStats):
        # z-score; the small epsilon guards against a zero std.
        return (x - stats.mean) / (stats.std + 1e-6)

    def _normalize_quantile(self, x, stats: NormStats):
        assert stats.q01 is not None
        assert stats.q99 is not None
        # q01 -> -1 and q99 -> ~1.
        return (x - stats.q01) / (stats.q99 - stats.q01 + 1e-6) * 2.0 - 1.0
def apply_tree(
    tree: at.PyTree[T], selector: at.PyTree[S], fn: Callable[[T, S], T], *, strict: bool = False
) -> at.PyTree[T]:
    """Apply ``fn(leaf, selector_leaf)`` to each leaf of ``tree`` whose flattened key is in ``selector``.

    Both trees are flattened with ``flatten_dict``. Leaves without a matching selector key are
    passed through unchanged, and the result is rebuilt with ``unflatten_dict``.

    Raises:
        ValueError: If ``strict`` is True and a selector key does not exist in ``tree``.
    """
    flat_tree = flatten_dict(tree)
    flat_selector = flatten_dict(selector)

    if strict:
        for key in flat_selector:
            if key not in flat_tree:
                raise ValueError(f"Selector key {key} not found in tree")

    result = {}
    for key, value in flat_tree.items():
        result[key] = fn(value, flat_selector[key]) if key in flat_selector else value
    return unflatten_dict(result)
-
InjectDefaultPrompt
如果样本中没有 prompt 且配置了默认 prompt, 则注入默认 prompt (转成 numpy 数组). 这里默认 prompt 为 None, 因此该变换直接跳过.
@dataclasses.dataclass(frozen=True)
class InjectDefaultPrompt(DataTransformFn):
    """Fill in ``data["prompt"]`` with a default prompt when the sample has none.

    Samples that already carry a prompt are left untouched, and so is everything when no default
    is configured.
    """

    prompt: str | None

    def __call__(self, data: DataDict) -> DataDict:
        # In this post's config the default prompt is None, so this transform is a no-op.
        if self.prompt is None or "prompt" in data:
            return data
        data["prompt"] = np.asarray(self.prompt)
        return data
-
ResizeImages
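采用双线性插值 (默认 Image.BILINEAR) 并补零 (pad) 的方式, 仿照 tf.image.resize_with_pad, 在保持宽高比的前提下把每张图片 resize 到目标大小. 如果图片本身已是目标大小, 则直接返回.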
@dataclasses.dataclass(frozen=True)
class ResizeImages(DataTransformFn):
    """Resize every camera image in ``data["image"]`` to ``(height, width)`` with padding.

    The resized images replace the ``data["image"]`` dict, and ``data`` is returned.
    """

    height: int
    width: int

    def __call__(self, data: DataDict) -> DataDict:
        resized = {}
        for name, image in data["image"].items():
            resized[name] = image_tools.resize_with_pad(image, self.height, self.width)
        data["image"] = resized
        return data
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
    """Replicates tf.image.resize_with_pad for multiple images using PIL.

    Each image is resized to the target height and width (like tf.image.resize_with_pad:
    aspect-preserving resize, then zero padding) via ``_resize_with_pad_pil``.

    Args:
        images: Images in [..., height, width, channel] format; any number of leading batch dims.
        height: The target height of the image.
        width: The target width of the image.
        method: The PIL interpolation method to use. Default is bilinear.

    Returns:
        The resized images in [..., height, width, channel], with the same leading dims.
    """
    # Nothing to do when the spatial size already matches.
    if images.shape[-3:-1] == (height, width):
        return images

    leading_shape = images.shape[:-3]
    # Collapse all leading dims so PIL can process one HWC frame at a time.
    frames = images.reshape(-1, *images.shape[-3:])
    resized_frames = [_resize_with_pad_pil(Image.fromarray(frame), height, width, method=method) for frame in frames]
    stacked = np.stack(resized_frames)
    return stacked.reshape(*leading_shape, *stacked.shape[-3:])
-
TokenizePrompt
使用 PaliGemma tokenizer 对 prompt 进行分词 (tokenize), 得到 tokenized_prompt 及其对应的 mask:
@dataclasses.dataclass(frozen=True)
class TokenizePrompt(DataTransformFn):
    """Replace the sample's ``prompt`` with ``tokenized_prompt`` / ``tokenized_prompt_mask``.

    The ``prompt`` key is popped from the incoming dict (which mutates it). The prompt may be a
    plain ``str`` or an object exposing ``.item()``, such as a 0-d numpy array.

    Raises:
        ValueError: If the sample has no prompt (or the prompt is None).
    """

    tokenizer: _tokenizer.PaligemmaTokenizer

    def __call__(self, data: DataDict) -> DataDict:
        prompt = data.pop("prompt", None)
        if prompt is None:
            raise ValueError("Prompt is required")
        text = prompt if isinstance(prompt, str) else prompt.item()
        tokens, token_masks = self.tokenizer.tokenize(text)
        return {**data, "tokenized_prompt": tokens, "tokenized_prompt_mask": token_masks}
最后我们看下 `TorchDataLoader` 类:
class TorchDataLoader:
def __init__(
self,
dataset,
local_batch_size: int,
*,
sharding: jax.sharding.Sharding | None = None,
shuffle: bool = False,
num_batches: int | None = None,
num_workers: int = 0,
seed: int = 0,
):
"""Create a PyTorch data loader.
Args:
dataset: The dataset to load.
local_batch_size: The local batch size for each process.
sharding: The sharding to use for the data loader.
shuffle: Whether to shuffle the data.
num_batches: If provided, determines the number of returned batches. If the
number is larger than the number of batches in the dataset, the data loader
will loop over the dataset. If not provided, will iterate over the dataset
indefinitely.
num_workers: The number of worker processes to use. If zero, the data loader will
execute in the main process.
seed: The seed to use for shuffling the data.
"""
if jax.process_count() > 1:
raise NotImplementedError("Data loading with multiple processes is not supported.")
if len(dataset) < local_batch_size:
raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
if sharding is None:
# Use data parallel sharding by default.
sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ("B",)),
jax.sharding.PartitionSpec("B"),
)
self._sharding = sharding
self._num_batches = num_batches
mp_context = None
if num_workers > 0:
mp_context = multiprocessing.get_context("spawn")
generator = torch.Generator()
generator.manual_seed(seed)
self._data_loader = torch.utils.data.DataLoader(
typing.cast(torch.utils.data.Dataset, dataset),
batch_size=local_batch_size,
shuffle=shuffle,
num_workers=num_workers,
multiprocessing_context=mp_context,
persistent_workers=num_workers > 0,
collate_fn=_collate_fn,
worker_init_fn=_worker_init_fn,
drop_last=True,
generator=generator,
)
@property
def torch_loader(self) -> torch.utils.data.DataLoader:
return self._data_loader
def __iter__(self):
num_items = 0
while True:
data_iter = iter(self._data_loader)
while True:
if self._num_batches is not None and num_items >= self._num_batches:
return
try:
batch = next(data_iter)
except StopIteration:
break # We've exhausted the dataset. Create a new iterator and start over.
num_items += 1
yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
def _collate_fn(items):
"""Collate the batch elements into batched numpy arrays."""
# Make sure to convert to numpy arrays before stacking since some of the incoming elements
# may be JAX arrays.
return jax.tree.map(lambda *x: np.stack(np.asarray(x), axis=0), *items)
def _worker_init_fn(worker_id: int) -> None:
"""Tell JAX inside the worker process not to preallocate the GPU memory."""
# NOTE: This is called after jax is imported inside the worker process. This
# means that this approach will not work for selecting the backend.
# 此代码能够关闭 JAX 运行时预分配几乎全部 GPU/TPU 内存的默认行为。开启预分配(默认是开启的)时,JAX 会预留大部分设备内存,这虽然能减少内存碎片,但可能会使其他 GPU 应用程序无法正常运行。将其设置为"false"之后,JAX 会根据实际需求来动态分配内存。
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# 这行代码把内存分配器切换成了平台默认的分配器,而非 JAX 自己的缓存分配器。平台分配器可以更好地和其他 GPU 应用程序共享内存,不过在某些情形下,内存碎片问题可能会变得更严重。
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
class DataLoaderImpl(DataLoader):
    """Adapt a raw batch loader into ``(Observation, actions)`` pairs for training."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        """Return the data config this loader was built with."""
        return self._data_config

    def __iter__(self):
        for batch in self._data_loader:
            observation = _model.Observation.from_dict(batch)
            yield observation, batch["actions"]
5. 日志配置
我们通过 logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}") 把第一个 batch 的结构打印出来 ([0] 为 Observation, [1] 为 actions), 如下所示:
[0].images['base_0_rgb']: (64, 224, 224, 3)@float32
[0].images['left_wrist_0_rgb']: (64, 224, 224, 3)@float32
[0].images['right_wrist_0_rgb']: (64, 224, 224, 3)@float32
[0].image_masks['base_0_rgb']: (64,)@bool
[0].image_masks['left_wrist_0_rgb']: (64,)@bool
[0].image_masks['right_wrist_0_rgb']: (64,)@bool
[0].state: (64, 32)@float32
[0].tokenized_prompt: (64, 90)@int32
[0].tokenized_prompt_mask: (64, 90)@bool
[1]: (64, 50, 32)@float32 (2828438:train.py:231)
下面的代码会取 batch 中前 (最多) 5 个样本, 把每个样本的多路相机图片沿宽度方向拼接成一张图, 并通过 wandb 进行记录.
# Log up to 5 samples of the first batch to wandb. For each sample, all camera views are
# concatenated along axis 1 (the width axis for HWC images) into a single strip.
first_obs = batch[0]
num_logged = min(5, len(next(iter(first_obs.images.values()))))
images_to_log = [
    wandb.Image(np.concatenate([np.array(view[i]) for view in first_obs.images.values()], axis=1))
    for i in range(num_logged)
]
wandb.log({"camera_views": images_to_log}, step=0)
6. 初始化state变量
mesh: Mesh(device_ids=array([[0], [1], [2], [3], [4], [5], [6], [7]]), axis_names=('batch', 'fsdp'), axis_types={Hidden: ('batch', 'fsdp')})
init_rng: Array((), dtype=key<fry>) overlaying:[ 64467757 2916123636]
resuming: False
将上述变量传入 `init_train_state` 函数, 该函数如下所示:
@at.typecheck
def init_train_state(
    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Build the FSDP-sharded training state (params, optimizer state, optional EMA params).

    Args:
        config: Training config. It supplies the model, optimizer, LR schedule, freeze/trainable
            filters, EMA decay and weight loader.
        init_rng: PRNG key used to initialize the model.
        mesh: Device mesh used to derive the FSDP sharding of the train state.
        resume: If True, skip initialization and weight loading and return only the abstract
            (shape/dtype) train state. Presumably the caller then restores it from a checkpoint
            (TODO confirm).

    Returns:
        ``(train_state, state_sharding)``. ``train_state`` is the abstract shape tree from
        ``jax.eval_shape`` when ``resume`` is True; otherwise it is the initialized, sharded state.
    """
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        # initialize the model (and its parameters).
        model = config.model.create(model_rng)

        # Merge the partial params into the model.
        if partial_params is not None:
            graphdef, state = nnx.split(model)
            # This will produce an error if the partial params are not a subset of the state.
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params (those matched by config.freeze_filter) to bfloat16.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            # Optimizer state only covers the params selected by config.trainable_filter.
            opt_state=tx.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            # EMA params start from the initial params when EMA is enabled, otherwise None.
            ema_params=None if config.ema_decay is None else params,
        )

    # Trace `init` abstractly to get the train-state shapes/dtypes without running the computation.
    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    # Load the pretrained (possibly partial) weights, checked against the expected param tree.
    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    # Inputs to the jitted init are replicated on every device; outputs use the FSDP sharding.
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize the train state and mix in the partial params.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),  # donate the partial params buffer.
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding
(1) 创建优化器optimizer
# Build the optax optimizer from the configured optimizer and learning-rate schedule (no weight-decay mask).
tx = _optimizer.create_optimizer(
    optimizer=config.optimizer,
    lr_schedule=config.lr_schedule,
    weight_decay_mask=None,
)
def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    """Build an optax optimizer whose learning rate follows ``lr_schedule``.

    Args:
        optimizer: Optimizer config; its ``create`` method builds the transformation.
        lr_schedule: LR schedule config; its ``create`` method builds the schedule.
        weight_decay_mask: Optional mask forwarded unchanged to ``optimizer.create``
            (``init_train_state`` passes None).
    """
    schedule = lr_schedule.create()
    return optimizer.create(schedule, weight_decay_mask=weight_decay_mask)
这里学习率调度 (lr schedule) 选用 `CosineDecaySchedule` (带 warmup 的余弦衰减), 如下所示:
@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Learning-rate schedule: warmup to ``peak_lr``, then cosine decay down to ``decay_lr``."""

    # Number of warmup steps before reaching peak_lr.
    warmup_steps: int = 1_000
    # Learning rate at the end of warmup.
    peak_lr: float = 2.5e-5
    # Passed to optax as `decay_steps` (the schedule length).
    decay_steps: int = 30_000
    # Final learning rate of the cosine decay.
    decay_lr: float = 2.5e-6

    def create(self) -> optax.Schedule:
        # Warmup starts from a small non-zero learning rate rather than 0.
        initial_lr = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )