mip-NeRF code debug


Code: https://github.com/google/mipnerf
Translated commentary: https://blog.csdn.net/qq_43620967/article/details/124458976

mip-NeRF README

This repository contains the code release for Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. The implementation is written in JAX and is a fork of Google's JaxNeRF implementation. Contact Jon Barron if you encounter any issues.


Installation

We recommend using Anaconda to set up the environment. Run the following commands:


git clone https://github.com/google/mipnerf.git; cd mipnerf

conda create --name mipnerf python=3.6.13; conda activate mipnerf

conda install pip; pip install --upgrade pip

pip install -r requirements.txt

pip install -r requirements.txt may fail here. Change it to

pip3 install -r requirements.txt

and it will go through. If the install times out, add a longer timeout:

pip3 install -r requirements.txt --default-timeout=500

[Optional] Install GPU and TPU support for Jax


pip install --upgrade jax jaxlib==0.1.65+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html

What I installed instead:

pip3 install --upgrade jax==0.2.3 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Data

Then, you need to download the datasets from the official NeRF Google Drive. Please download and unzip nerf_synthetic.zip and nerf_llff_data.zip.

Generate multiscale dataset

You can generate the multiscale dataset used in the paper by running the following command:

python scripts/convert_blender_data.py --blenderdir /nerf_synthetic --outdir /multiscale

Running

Example scripts for training mip-NeRF on single scenes from the three datasets used in the paper can be found in scripts/. You will need to change the paths to point to wherever the datasets are located. Gin configuration files for our model and some ablations can be found in configs/. Example scripts for evaluating on the test set of each scene can also be found in scripts/; afterwards you can use scripts/summary.ipynb to produce the error metrics for all scenes in the same format as the tables in the paper.

OOM errors

You may need to reduce the batch size to avoid out-of-memory errors. For example, the model can be run on an NVIDIA 3080 (10 GB) using the following flag:

--gin_param="Config.batch_size = 1024"

Citation

If you use this software package, please cite our paper:

@misc{barron2021mipnerf,
title={Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields},
author={Jonathan T. Barron and Ben Mildenhall and Matthew Tancik and Peter Hedman and Ricardo Martin-Brualla and Pratul P. Srinivasan},
year={2021},
eprint={2103.13415},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Acknowledgements

Thanks to Boyang Deng for JaxNeRF.

Project path

/media/hp/C440AC3040AC2ADE/ysh/mipnerf-main/

scripts/convert_blender_data.py

def main

def main(unused_argv):

  blenderdir = FLAGS.blenderdir
  outdir = FLAGS.outdir
  n_down = FLAGS.n_down
  if not os.path.exists(outdir):
    os.makedirs(outdir)

  dirs = [os.path.join(blenderdir, f) for f in os.listdir(blenderdir)]
  dirs = [d for d in dirs if os.path.isdir(d)]
  print(dirs)
  for basedir in dirs:
    print()
    newdir = os.path.join(outdir, os.path.basename(basedir))
    print('Converting from', basedir, 'to', newdir)
    convert_to_nerfdata(basedir, newdir, n_down)

def load_renderings

def load_renderings(data_dir, split):
  """Load images and metadata from disk."""
  f = 'transforms_{}.json'.format(split)
  with open(path.join(data_dir, f), 'r') as fp:
    meta = json.load(fp)
  images = []
  cams = []
  print('Loading imgs')
  for frame in meta['frames']:
    fname = os.path.join(data_dir, frame['file_path'] + '.png')
    with open(fname, 'rb') as imgin:
      image = np.array(Image.open(imgin), dtype=np.float32) / 255.

    cams.append(frame['transform_matrix'])
    images.append(image)
  ret = {}
  ret['images'] = np.stack(images, axis=0)
  print('Loaded all images, shape is', ret['images'].shape)
  ret['camtoworlds'] = np.stack(cams, axis=0)
  w = ret['images'].shape[2]
  camera_angle_x = float(meta['camera_angle_x'])
  ret['focal'] = .5 * w / np.tan(.5 * camera_angle_x)
  return ret

transforms_train.json

camera_angle_x is the camera's horizontal field of view (in radians); it is used to compute the focal length:

ret['focal'] = .5 * w / np.tan(.5 * camera_angle_x)

file_path gives the image path:

fname = os.path.join(data_dir, frame['file_path'] + '.png')

rotation is unused.

transform_matrix is the pose matrix that converts camera coordinates to world coordinates (camera-to-world). With the extrinsic matrix Cex, camera coordinates = Cex * world coordinates, so world coordinates = Cex^-1 * camera coordinates; transform_matrix is that inverse, Cex^-1.

cams.append(frame['transform_matrix'])
ret['camtoworlds'] = np.stack(cams, axis=0)
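As a quick sanity check (a minimal standalone sketch, not from the repository): the camera center in world coordinates is the translation column of transform_matrix, and a camera-space direction is rotated into world space by its upper-left 3x3 block.

import numpy as np

# Hypothetical camera-to-world matrix: identity rotation, camera at z = 4.
c2w = np.array([[1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 4.],
                [0., 0., 0., 1.]])

origin_world = c2w[:3, -1]          # camera center in world coordinates
dir_cam = np.array([0., 0., -1.])   # the camera looks down -z in its own frame
dir_world = c2w[:3, :3] @ dir_cam   # rotate the direction into world space
print(origin_world, dir_world)      # [0. 0. 4.] [ 0.  0. -1.]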


def convert_to_nerfdata

def convert_to_nerfdata(basedir, newdir, n_down):
  """Convert Blender data to multiscale."""
  if not os.path.exists(newdir):
    os.makedirs(newdir)
  splits = ['train', 'val', 'test']
  bigmeta = {}

  for split in splits:
    print('Split', split)

    data = load_renderings(basedir, split)

    imgdir = 'images_{}'.format(split)
    os.makedirs(os.path.join(newdir, imgdir), exist_ok=True)
    fnames = []
    widths = []
    heights = []
    focals = []
    cam2worlds = []
    lossmults = []
    labels = []
    nears, fars = [], []
    f = data['focal']
    print('Saving images')
    for i, img in enumerate(data['images']):
      for j in range(n_down):
        fname = '{}/{:03d}_d{}.png'.format(imgdir, i, j)
        fnames.append(fname)
        fname = os.path.join(newdir, fname)
        with open(fname, 'wb') as imgout:
          img8 = Image.fromarray(np.uint8(img * 255))
          img8.save(imgout)
        widths.append(img.shape[1])
        heights.append(img.shape[0])
        focals.append(f / 2**j)
        cam2worlds.append(data['camtoworlds'][i].tolist())
        lossmults.append(4.**j)
        labels.append(j)
        nears.append(2.)
        fars.append(6.)
        img = down2(img)

def down2

def down2(img):
  sh = img.shape
  return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3))

The image is halved repeatedly. sh changes as:

(800, 800, 4) -> (100, 100, 4)

The side length steps through:

[800, 400, 200, 100]
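A quick way to verify the shape progression (a standalone sketch reusing the down2 above):

import numpy as np

def down2(img):
  sh = img.shape
  return np.mean(np.reshape(img, [sh[0] // 2, 2, sh[1] // 2, 2, -1]), (1, 3))

img = np.random.rand(800, 800, 4).astype(np.float32)
for j in range(4):  # n_down = 4
  print(j, img.shape)  # (800, 800, 4), (400, 400, 4), (200, 200, 4), (100, 100, 4)
  img = down2(img)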

    meta = {}
    meta['file_path'] = fnames
    meta['cam2world'] = cam2worlds
    meta['width'] = widths
    meta['height'] = heights
    meta['focal'] = focals
    meta['label'] = labels
    meta['near'] = nears
    meta['far'] = fars
    meta['lossmult'] = lossmults

    fx = np.array(focals)
    fy = np.array(focals)
    cx = np.array(meta['width']) * .5
    cy = np.array(meta['height']) * .5
    arr0 = np.zeros_like(cx)
    arr1 = np.ones_like(cx)
    k_inv = np.array([
        [arr1 / fx, arr0, -cx / fx],
        [arr0, -arr1 / fy, cy / fy],
        [arr0, arr0, -arr1],
    ])
    k_inv = np.moveaxis(k_inv, -1, 0)
    meta['pix2cam'] = k_inv.tolist()

    bigmeta[split] = meta

  for k in bigmeta:
    for j in bigmeta[k]:
      print(k, j, type(bigmeta[k][j]), np.array(bigmeta[k][j]).shape)

  jsonfile = os.path.join(newdir, 'metadata.json')
  with open(jsonfile, 'w') as f:
    json.dump(bigmeta, f, ensure_ascii=False, indent=4)

meta['pix2cam']

Pixel coordinates and camera-frame coordinates are related through the camera intrinsic matrix K, which contains the focal lengths fx, fy and the principal point (cx, cy).
k_inv is the inverse of the intrinsic matrix, i.e. pix2cam, the pixel-to-camera transform.

For a given dataset the camera intrinsics are usually fixed, and are often provided in a file named intrinsics.txt.
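Written out for a single camera (a sketch with the same sign conventions as the code above: y flipped and z pointing backwards, the OpenGL-style convention used by Blender/NeRF), k_inv maps a pixel (u, v) to a direction in the camera frame:

import numpy as np

w, h = 800, 800
focal = 1111.1  # hypothetical value; computed from camera_angle_x as shown earlier
fx, fy = focal, focal
cx, cy = w * 0.5, h * 0.5

k_inv = np.array([[1 / fx,       0, -cx / fx],
                  [     0, -1 / fy,  cy / fy],
                  [     0,       0,      -1.]])

# Direction (in camera coordinates) through pixel (u, v).
u, v = 400.5, 400.5
d_cam = k_inv @ np.array([u, v, 1.])
print(d_cam)  # roughly [0, 0, -1]: the center pixel looks straight down -z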

train_blender.sh


SCENE=lego
EXPERIMENT=debug
TRAIN_DIR=/Users/barron/tmp/nerf_results/$EXPERIMENT/$SCENE
DATA_DIR=/Users/barron/data/nerf_synthetic/$SCENE

rm $TRAIN_DIR/*
python -m train \
  --data_dir=$DATA_DIR \
  --train_dir=$TRAIN_DIR \
  --gin_file=configs/blender.gin \
  --logtostderr

python -m

Run a Python module from the library as a script.

1. python xxx.py
2. python -m xxx (note: a module name, without the .py suffix)

These are two ways of loading a .py file:

1 is direct execution;
2 is equivalent to an import: the file is run as a module.

The main difference between them is what ends up in sys.path. sys.path is Python's analogue of PATH on Linux: the set of directories the interpreter searches for imports. Direct execution prepends the script's own directory to sys.path, while -m prepends the current working directory. That is why python -m train, run from the repository root, can import the internal package.
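One way to see the difference for yourself (a hypothetical probe script, not part of the repository):

# probe.py
import sys
print(sys.path[0])

python some_dir/probe.py prints some_dir, whereas cd some_dir; python -m probe prints the current working directory (shown as an empty string on some Python versions, which also means the cwd).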

Run the script

Change the directories to your own dataset paths:

TRAIN_DIR=/media/hp/C440AC3040AC2ADE/ysh/mipnerf-main/nerf_example_data/nerf_results/$EXPERIMENT/$SCENE
DATA_DIR=/media/hp/C440AC3040AC2ADE/ysh/mipnerf-main/nerf_example_data/nerf_synthetic/$SCENE
conda activate mipnerf
bash scripts/train_blender.sh

configs/blender.gin

Config.dataset_loader = 'blender'
Config.batching = 'single_image'

train.py

from absl import app
from absl import flags
from internal import datasets
from internal import math
from internal import models
from internal import utils
from internal import vis

config setup
FLAGS: purpose and usage

FLAGS = flags.FLAGS
utils.define_common_flags()
flags.DEFINE_integer('render_every', 5000,
                     'The number of steps between test set image renderings.')

jax.config.parse_flags_with_absl()

utils.define_common_flags()

def define_common_flags():

  flags.DEFINE_multi_string('gin_file', None,
                            'List of paths to the config files.')
  flags.DEFINE_multi_string(
      'gin_param', None, 'Newline separated list of Gin parameter bindings.')
  flags.DEFINE_string('train_dir', None, 'where to store ckpts and logs')
  flags.DEFINE_string('data_dir', None, 'input data directory.')
  flags.DEFINE_integer(
      'chunk', 8192,
      'the size of chunks for evaluation inferences, set to the value that'
      'fits your GPU/TPU memory.')

main

1.config

def main(unused_argv):
  rng = random.PRNGKey(20200823)

  np.random.seed(20201473 + jax.host_id())

  config = utils.load_config()

  if config.batch_size % jax.device_count() != 0:
    raise ValueError('Batch size must be divisible by the number of devices.')

utils.load_config()

def load_config():
  gin.parse_config_files_and_bindings(flags.FLAGS.gin_file,
                                      flags.FLAGS.gin_param)
  return Config()

class Config

class Config:
  """Configuration flags for everything."""
  dataset_loader: str = 'multicam'
  batching: str = 'all_images'
  batch_size: int = 4096
  factor: int = 0
  spherify: bool = False
  render_path: bool = False
  llffhold: int = 8
  lr_init: float = 5e-4
  lr_final: float = 5e-6
  lr_delay_steps: int = 2500
  lr_delay_mult: float = 0.01
  grad_max_norm: float = 0.

  grad_max_val: float = 0.

  max_steps: int = 1000000
  save_every: int = 100000
  print_every: int = 100
  gc_every: int = 10000
  test_render_interval: int = 1
  disable_multiscale_loss: bool = False
  randomized: bool = True
  near: float = 2.

  far: float = 6.

  coarse_loss_mult: float = 0.1
  weight_decay_mult: float = 0.

  white_bkgd: bool = True

2.dataset

  dataset = datasets.get_dataset('train', FLAGS.data_dir, config)
  test_dataset = datasets.get_dataset('test', FLAGS.data_dir, config)

def get_dataset

def get_dataset(split, train_dir, config):
  return dataset_dict[config.dataset_loader](split, train_dir, config)
dataset_dict = {
    'blender': Blender,
    'llff': LLFF,
    'multicam': Multicam,
}

class Blender

class Blender(Dataset):
  """Blender Dataset."""

class Dataset

class Dataset(threading.Thread):
  """Dataset Base Class."""

  def __init__(self, split, data_dir, config):
    super(Dataset, self).__init__()
    self.queue = queue.Queue(3)
    self.daemon = True
    self.split = split
    self.data_dir = data_dir
    self.near = config.near
    self.far = config.far
    if split == 'train':
      self._train_init(config)
    elif split == 'test':
      self._test_init(config)
    else:
      raise ValueError(
          'the split argument should be either \'train\' or \'test\', set '
          'to {} here.'.format(split))
    self.batch_size = config.batch_size // jax.host_count()
    self.batching = config.batching
    self.render_path = config.render_path
    self.start()
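The rest of the class is not shown in this post. A sketch of the assumed structure (following the JaxNeRF-style dataset this is forked from): run() is the daemon-thread body that keeps pushing ready batches into the bounded queue, and iteration pops them, so batch preparation overlaps with training.

  def __iter__(self):
    return self

  def __next__(self):
    """Fetch the next prepared batch (blocks until one is ready)."""
    x = self.queue.get()
    if self.split == 'train':
      return utils.shard(x)      # reshape for pmap across local devices
    else:
      return utils.to_device(x)

  def run(self):
    next_func = self._next_train if self.split == 'train' else self._next_test
    while True:
      self.queue.put(next_func())  # blocks once 3 batches are buffered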


def _train_init

  def _train_init(self, config):
    """Initialize training."""
    self._load_renderings(config)
    self._generate_rays()

    if config.batching == 'all_images':

      self.images = self.images.reshape([-1, 3])
      self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                       self.rays)
    elif config.batching == 'single_image':
      self.images = self.images.reshape([-1, self.resolution, 3])
      self.rays = utils.namedtuple_map(
          lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
    else:
      raise NotImplementedError(
          f'{config.batching} batching strategy is not implemented.')

def _generate_rays

This describes how rays are generated in detail; one ray carries the following fields:

    self.rays = utils.Rays(
        origins=origins,
        directions=directions,
        viewdirs=viewdirs,
        radii=radii,
        lossmult=ones,
        near=ones * self.near,
        far=ones * self.far)

  def _generate_rays(self):
    """Generating rays for all images."""
    x, y = np.meshgrid(
        np.arange(self.w, dtype=np.float32),
        np.arange(self.h, dtype=np.float32),
        indexing='xy')
    camera_dirs = np.stack(
        [(x - self.w * 0.5 + 0.5) / self.focal,
         -(y - self.h * 0.5 + 0.5) / self.focal, -np.ones_like(x)],
        axis=-1)
    directions = ((camera_dirs[None, ..., None, :] *
                   self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
    origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
                              directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)

    # Distance from each unit-norm direction vector to its x-axis neighbor.
    dx = np.sqrt(
        np.sum((directions[:, :-1, :, :] - directions[:, 1:, :, :])**2, -1))
    dx = np.concatenate([dx, dx[:, -2:-1, :]], 1)

    # Cut the distance in half, and then round it out so that it's
    # halfway between inscribed by / circumscribed about the pixel.
    radii = dx[..., None] * 2 / np.sqrt(12)

    ones = np.ones_like(origins[..., :1])
    self.rays = utils.Rays(
        origins=origins,
        directions=directions,
        viewdirs=viewdirs,
        radii=radii,
        lossmult=ones,
        near=ones * self.near,
        far=ones * self.far)

utils.Rays

Rays = collections.namedtuple(
    'Rays',
    ('origins', 'directions', 'viewdirs', 'radii', 'lossmult', 'near', 'far'))


3.model

  rng, key = random.split(rng)
  model, variables = models.construct_mipnerf(key, dataset.peek())
  num_params = jax.tree_util.tree_reduce(
      lambda x, y: x + jnp.prod(jnp.array(y.shape)), variables, initializer=0)
  print(f'Number of parameters being optimized: {num_params}')
  optimizer = flax.optim.Adam(config.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables
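The tree_reduce call simply sums the element counts of every parameter array in the pytree. A minimal standalone sketch:

import jax
import jax.numpy as jnp

params = {'w1': jnp.zeros((3, 4)), 'b1': jnp.zeros((4,))}
num_params = jax.tree_util.tree_reduce(
    lambda x, y: x + jnp.prod(jnp.array(y.shape)), params, initializer=0)
print(num_params)  # 16 = 3*4 + 4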

def peek

  def peek(self):
    """在不出列的情况下,查看下一批训练或测试示例.

    Returns:
      batch: dict, has 'pixels' and 'rays'.

"""
    x = self.queue.queue[0].copy()
    if self.split == 'train':
      return utils.shard(x)
    else:
      return utils.to_device(x)
def shard(xs):
  """沿着第一维将多个设备的数据分割成碎片。"""
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
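For example, with 2 local devices a (1024, 3) batch becomes (2, 512, 3), so that pmap can map the leading axis onto devices (a sketch):

import jax
import jax.numpy as jnp

x = jnp.zeros((1024, 3))
n = jax.local_device_count()
sharded = x.reshape((n, -1) + x.shape[1:])
print(n, sharded.shape)  # with 2 devices: 2 (2, 512, 3)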

def construct_mipnerf

def construct_mipnerf(rng, example_batch):
  """Construct a Neural Radiance Field.

  Args:
    rng: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data.

  Returns:
    model: nn.Model. Nerf model with parameters.

    state: flax.Module.state. NeRF model state with stateful parameters.

"""
  model = MipNerfModel()
  key, rng = random.split(rng)
  init_variables = model.init(
      key,
      rng=rng,
      rays=utils.namedtuple_map(lambda x: x[0], example_batch['rays']),
      randomized=False,
      white_bkgd=False)
  return model, init_variables

Value of model

MipNerfModel(

    num_samples = 128
    num_levels = 2
    resample_padding = 0.01
    stop_level_grad = True
    use_viewdirs = True
    lindisp = False
    ray_shape = 'cone'
    min_deg_point = 0
    max_deg_point = 16
    deg_view = 4
    density_activation = softplus
    density_noise = 0.0
    density_bias = -1.0
    rgb_activation = sigmoid
    rgb_padding = 0.001
    disable_integration = False
)

4.learning rate

  learning_rate_fn = functools.partial(
      math.learning_rate_decay,
      lr_init=config.lr_init,
      lr_final=config.lr_final,
      max_steps=config.max_steps,
      lr_delay_steps=config.lr_delay_steps,
      lr_delay_mult=config.lr_delay_mult)

math.learning_rate_decay

Continuous learning rate decay function

def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  Returns lr_init when step == 0 and lr_final when step == max_steps, and is
  log-linearly interpolated everywhere else (equivalent to exponential decay).

  If lr_delay_steps > 0, the learning rate is scaled by a smooth function of
  lr_delay_mult, such that the initial learning rate is lr_init*lr_delay_mult
  at the start of optimization, but is eased back to the normal rate once
  step > lr_delay_steps.

  Args:
    step: int, the current optimization step.

    lr_init: float, the initial learning rate.

    lr_final: float, the final learning rate.

    max_steps: int, the number of steps during optimization.

    lr_delay_steps: int, the number of steps to delay the full learning rate.

    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning rate for current step 'step'.

"""
  if lr_delay_steps > 0:

    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * jnp.sin(
        0.5 * jnp.pi * jnp.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.

  t = jnp.clip(step / max_steps, 0, 1)
  log_lerp = jnp.exp(jnp.log(lr_init) * (1 - t) + jnp.log(lr_final) * t)
  return delay_rate * log_lerp
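With the Config defaults above (lr_init=5e-4, lr_final=5e-6, lr_delay_steps=2500, lr_delay_mult=0.01, max_steps=1e6), a quick sketch of the resulting schedule:

import jax.numpy as jnp

# learning_rate_decay as defined above.
for step in [0, 2500, 500_000, 1_000_000]:
  lr = learning_rate_decay(step, lr_init=5e-4, lr_final=5e-6,
                           max_steps=1_000_000, lr_delay_steps=2500,
                           lr_delay_mult=0.01)
  print(step, float(lr))
# step 0:         5e-6   (lr_init scaled by lr_delay_mult)
# step 2500:      ~5e-4  (delay fully relaxed)
# step 500000:    5e-5   (geometric midpoint of lr_init and lr_final)
# step 1000000:   5e-6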

5.function mapping (pmap)

  train_pstep = jax.pmap(
      functools.partial(train_step, model, config),
      axis_name='batch',
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))
  render_eval_pfn = jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=(2,),
      axis_name='batch',
  )
  ssim_fn = jax.jit(functools.partial(math.compute_ssim, max_val=1.))

6.load a trained checkpoint

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)

  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

7.summary_writer

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

8.training loop


  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  rng = rng + jax.host_id()
  keys = random.split(rng, jax.local_device_count())
  gc.disable()
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, config.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % config.gc_every == 0:
      gc.collect()

Log training summaries

This is kept outside the host_id check because, in multi-host evaluation, all hosts need to run inference even though only host 0 records the results.


    if jax.host_id() == 0:
      if step % config.print_every == 0:
        summary_writer.scalar('num_params', num_params, step)
        summary_writer.scalar('train_loss', stats.loss[0], step)
        summary_writer.scalar('train_psnr', stats.psnr[0], step)
        for i, l in enumerate(stats.losses[0]):
          summary_writer.scalar(f'train_losses_{i}', l, step)
        for i, p in enumerate(stats.psnrs[0]):
          summary_writer.scalar(f'train_psnrs_{i}', p, step)
        summary_writer.scalar('weight_l2', stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        max_grad_norm = np.max(
            np.concatenate([s.grad_norm for s in stats_trace]))
        avg_grad_norm = np.mean(
            np.concatenate([s.grad_norm for s in stats_trace]))
        max_clipped_grad_norm = np.max(
            np.concatenate([s.grad_norm_clipped for s in stats_trace]))
        max_grad_max = np.max(
            np.concatenate([s.grad_abs_max for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar('train_avg_loss', avg_loss, step)
        summary_writer.scalar('train_avg_psnr', avg_psnr, step)
        summary_writer.scalar('train_max_grad_norm', max_grad_norm, step)
        summary_writer.scalar('train_avg_grad_norm', avg_grad_norm, step)
        summary_writer.scalar('train_max_clipped_grad_norm',
                              max_clipped_grad_norm, step)
        summary_writer.scalar('train_max_grad_max', max_grad_max, step)
        summary_writer.scalar('learning_rate', lr, step)
        steps_per_sec = config.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = config.batch_size * steps_per_sec
        summary_writer.scalar('train_steps_per_sec', steps_per_sec, step)
        summary_writer.scalar('train_rays_per_sec', rays_per_sec, step)
        precision = int(np.ceil(np.log10(config.max_steps))) + 1
        print(('{:' + '{:d}'.format(precision) + 'd}').format(step) +
              f'/{config.max_steps:d}: ' + f'i_loss={stats.loss[0]:0.4f}, ' +
              f'avg_loss={avg_loss:0.4f}, ' +
              f'weight_l2={stats.weight_l2[0]:0.2e}, ' + f'lr={lr:0.2e}, ' +
              f'{rays_per_sec:0.0f} rays/sec')

      if step % config.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

Test-set evaluation


    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:

      t_eval_start = time.time()
      eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                   state)).optimizer.target
      test_case = next(test_dataset)
      pred_color, pred_distance, pred_acc = models.render_image(
          functools.partial(render_eval_pfn, eval_variables),
          test_case['rays'],
          keys[0],
          chunk=FLAGS.chunk)

      vis_suite = vis.visualize_suite(pred_distance, pred_acc)

      if jax.host_id() == 0:
        psnr = math.mse_to_psnr(((pred_color - test_case['pixels'])**2).mean())
        ssim = ssim_fn(pred_color, test_case['pixels'])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case['rays'].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar('test_rays_per_sec', rays_per_sec, step)
        print(f'Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec')
        summary_writer.scalar('test_psnr', psnr, step)
        summary_writer.scalar('test_ssim', ssim, step)
        summary_writer.image('test_pred_color', pred_color, step)
        for k, v in vis_suite.items():
          summary_writer.image('test_pred_' + k, v, step)
        summary_writer.image('test_pred_acc', pred_acc, step)
        summary_writer.image('test_target', test_case['pixels'], step)

Save the final trained checkpoint

  if config.max_steps % config.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(config.max_steps), keep=100)

def train_step

One optimization step.

def train_step(model, config, rng, state, batch, lr):
  """One optimization step.

  Args:
    model: The linen model.

    config: The configuration.

    rng: jnp.ndarray, random number generator.

    state: utils.TrainState, state of the model/optimizer.

    batch: dict, a mini-batch of data for training.

    lr: float, real-time learning rate.

  Returns:
    new_state: utils.TrainState, new training state.

    stats: list. [(loss, psnr), (loss_coarse, psnr_coarse)].

    rng: jnp.ndarray, updated random number generator.

"""
  rng, key = random.split(rng)

  def loss_fn(variables):

    def tree_sum_fn(fn):
      return jax.tree_util.tree_reduce(
          lambda x, y: x + fn(y), variables, initializer=0)

    weight_l2 = config.weight_decay_mult * (
        tree_sum_fn(lambda z: jnp.sum(z**2)) /
        tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

    ret = model.apply(
        variables,
        key,
        batch['rays'],
        randomized=config.randomized,
        white_bkgd=config.white_bkgd)

    mask = batch['rays'].lossmult
    if config.disable_multiscale_loss:
      mask = jnp.ones_like(mask)

    losses = []
    for (rgb, _, _) in ret:
      losses.append(
          (mask * (rgb - batch['pixels'][..., :3])**2).sum() / mask.sum())
    losses = jnp.array(losses)

    loss = (
        config.coarse_loss_mult * jnp.sum(losses[:-1]) + losses[-1] + weight_l2)

    stats = utils.Stats(
        loss=loss,
        losses=losses,
        weight_l2=weight_l2,
        psnr=0.0,
        psnrs=0.0,
        grad_norm=0.0,
        grad_abs_max=0.0,
        grad_norm_clipped=0.0,
    )
    return loss, stats

  (_, stats), grad = (
      jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
  grad = jax.lax.pmean(grad, axis_name='batch')
  stats = jax.lax.pmean(stats, axis_name='batch')

  def tree_norm(tree):
    return jnp.sqrt(
        jax.tree_util.tree_reduce(
            lambda x, y: x + jnp.sum(y**2), tree, initializer=0))

  if config.grad_max_val > 0:
    clip_fn = lambda z: jnp.clip(z, -config.grad_max_val, config.grad_max_val)
    grad = jax.tree_util.tree_map(clip_fn, grad)

  grad_abs_max = jax.tree_util.tree_reduce(
      lambda x, y: jnp.maximum(x, jnp.max(jnp.abs(y))), grad, initializer=0)

  grad_norm = tree_norm(grad)
  if config.grad_max_norm > 0:
    mult = jnp.minimum(1, config.grad_max_norm / (1e-7 + grad_norm))
    grad = jax.tree_util.tree_map(lambda z: mult * z, grad)
  grad_norm_clipped = tree_norm(grad)

  new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
  new_state = state.replace(optimizer=new_optimizer)

  psnrs = math.mse_to_psnr(stats.losses)
  stats = utils.Stats(
      loss=stats.loss,
      losses=stats.losses,
      weight_l2=stats.weight_l2,
      psnr=psnrs[-1],
      psnrs=psnrs,
      grad_norm=grad_norm,
      grad_abs_max=grad_abs_max,
      grad_norm_clipped=grad_norm_clipped,
  )

  return new_state, stats, rng
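Written out (my reading of loss_fn above), with $m_i$ the per-ray lossmult and $K$ the number of levels:

$$\mathcal{L} = \lambda_c \sum_{k=0}^{K-2} \ell_k + \ell_{K-1} + \mathcal{L}_{\mathrm{weight}}, \qquad \ell_k = \frac{\sum_i m_i\,\lVert \hat{C}_k(\mathbf{r}_i) - C(\mathbf{r}_i)\rVert_2^2}{\sum_i m_i}$$

where $\lambda_c$ = coarse_loss_mult = 0.1 and $\mathcal{L}_{\mathrm{weight}}$ is weight_decay_mult times the mean squared parameter value.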

class MipNerfModel

@gin.configurable
class MipNerfModel(nn.Module):
  """Nerf NN Model with both coarse and fine MLPs."""
  num_samples: int = 128
  num_levels: int = 2
  resample_padding: float = 0.01
  stop_level_grad: bool = True
  use_viewdirs: bool = True
  lindisp: bool = False
  ray_shape: str = 'cone'
  min_deg_point: int = 0
  max_deg_point: int = 16
  deg_view: int = 4
  density_activation: Callable[..., Any] = nn.softplus
  density_noise: float = 0.

  density_bias: float = -1.

  rgb_activation: Callable[..., Any] = nn.sigmoid
  rgb_padding: float = 0.001
  disable_integration: bool = False

Value of self

MipNerfModel(

    num_samples = 128
    num_levels = 2
    resample_padding = 0.01
    stop_level_grad = True
    use_viewdirs = True
    lindisp = False
    ray_shape = 'cone'
    min_deg_point = 0
    max_deg_point = 16
    deg_view = 4
    density_activation = softplus
    density_noise = 0.0
    density_bias = -1.0
    rgb_activation = sigmoid
    rgb_padding = 0.001
    disable_integration = False

    MLP_0 = MLP(

        net_depth = 8
        net_width = 256
        net_depth_condition = 1
        net_width_condition = 128
        net_activation = relu
        skip_layer = 4
        num_rgb_channels = 3
        num_density_channels = 1
    )
)

def __call__

  @nn.compact
  def __call__(self, rng, rays, randomized, white_bkgd):
    """The mip-NeRF Model.

    Args:
      rng: jnp.ndarray, random number generator.

      rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.

      randomized: bool, use randomized stratified sampling.

      white_bkgd: bool, if True, use white as the background (black o.w.).

    Returns:
      ret: list, [*(rgb, distance, acc)]
"""

    mlp = MLP()

    ret = []
    for i_level in range(self.num_levels):
      key, rng = random.split(rng)

Hierarchical sampling strategy

      if i_level == 0:

        t_vals, samples = mip.sample_along_rays(
            key,
            rays.origins,
            rays.directions,
            rays.radii,
            self.num_samples,
            rays.near,
            rays.far,
            randomized,
            self.lindisp,
            self.ray_shape,
        )
      else:
        t_vals, samples = mip.resample_along_rays(
            key,
            rays.origins,
            rays.directions,
            rays.radii,
            t_vals,
            weights,
            randomized,
            self.ray_shape,
            self.stop_level_grad,
            resample_padding=self.resample_padding,
        )

def sample_along_rays

Stratified sampling along the rays.

def sample_along_rays(key, origins, directions, radii, num_samples, near, far,
                      randomized, lindisp, ray_shape):
  """Stratified sampling along the rays.

  Args:
    key: jnp.ndarray, random generator key.

    origins: jnp.ndarray(float32), [batch_size, 3], ray origins.

    directions: jnp.ndarray(float32), [batch_size, 3], ray directions.

    radii: jnp.ndarray(float32), [batch_size, 3], ray radii.

    num_samples: int.

    near: jnp.ndarray, [batch_size, 1], near clip.

    far: jnp.ndarray, [batch_size, 1], far clip.

    randomized: bool, use randomized stratified sampling.

    lindisp: bool, sample linearly in disparity rather than in depth.

    ray_shape: string, which shape to assume for the rays.

  Returns:
    t_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.

    means: jnp.ndarray, [batch_size, num_samples, 3], sampled means.

    covs: jnp.ndarray, [batch_size, num_samples, 3, 3], sampled covariances.

"""
  batch_size = origins.shape[0]

  t_vals = jnp.linspace(0., 1., num_samples + 1)
  if lindisp:
    t_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
  else:
    t_vals = near * (1. - t_vals) + far * t_vals

  if randomized:
    mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
    upper = jnp.concatenate([mids, t_vals[..., -1:]], -1)
    lower = jnp.concatenate([t_vals[..., :1], mids], -1)
    t_rand = random.uniform(key, [batch_size, num_samples + 1])
    t_vals = lower + (upper - lower) * t_rand
  else:

    t_vals = jnp.broadcast_to(t_vals, [batch_size, num_samples + 1])
  means, covs = cast_rays(t_vals, origins, directions, radii, ray_shape)
  return t_vals, (means, covs)

Cast rays (cone- or cylinder-shaped), characterize their sections, and compute the means and covariances.

def cast_rays(t_vals, origins, directions, radii, ray_shape, diag=True):
  """Cast rays (cone- or cylinder-shaped) and featurize sections of it.

  Args:
    t_vals: float array, the "fencepost" distances along the ray.

    origins: float array, the ray origin coordinates.

    directions: float array, the ray direction vectors.

    radii: float array, the radii (base radii for cones) of the rays.

    ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'.

    diag: boolean, whether or not the covariance matrices should be diagonal.

  Returns:
    a tuple of arrays of means and covariances.

"""
  t0 = t_vals[..., :-1]
  t1 = t_vals[..., 1:]
  if ray_shape == 'cone':
    gaussian_fn = conical_frustum_to_gaussian
  elif ray_shape == 'cylinder':
    gaussian_fn = cylinder_to_gaussian
  else:
    assert False
  means, covs = gaussian_fn(directions, t0, t1, radii, diag)
  means = means + origins[..., None, :]
  return means, covs
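conical_frustum_to_gaussian is not reproduced in this post. For reference, the closed form from the mip-NeRF paper (its eq. 7, quoted here from the paper rather than from this repository, so verify against internal/mip.py) writes the frustum between t0 and t1 in terms of $t_\mu = (t_0 + t_1)/2$ and $t_\delta = (t_1 - t_0)/2$:

$$\mu_t = t_\mu + \frac{2\,t_\mu t_\delta^2}{3 t_\mu^2 + t_\delta^2}, \qquad \sigma_t^2 = \frac{t_\delta^2}{3} - \frac{4\,t_\delta^4\,(12 t_\mu^2 - t_\delta^2)}{15\,(3 t_\mu^2 + t_\delta^2)^2}, \qquad \sigma_r^2 = r^2\!\left(\frac{t_\mu^2}{4} + \frac{5 t_\delta^2}{12} - \frac{4 t_\delta^4}{15\,(3 t_\mu^2 + t_\delta^2)}\right)$$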

t_vals

[[2.      2.03125 2.0625  ... 5.9375  5.96875 6.     ]
 [2.      2.03125 2.0625  ... 5.9375  5.96875 6.     ]
 [2.      2.03125 2.0625  ... 5.9375  5.96875 6.     ]
 ...

 [2.      2.03125 2.0625  ... 5.9375  5.96875 6.     ]
 [2.      2.03125 2.0625  ... 5.9375  5.96875 6.     ]
 [2.      2.03125 2.0625  ... 5.9375  5.96875 6.     ]]
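This listing is the randomized=False case: 129 evenly spaced fenceposts from near=2 to far=6 (spacing 4/128 = 0.03125), broadcast to every ray in the batch. With randomized=True each fencepost is instead drawn uniformly within its bin, e.g. (a standalone sketch):

import jax.numpy as jnp
from jax import random

near, far, num_samples = 2., 6., 128
t_vals = jnp.linspace(0., 1., num_samples + 1)
t_vals = near * (1. - t_vals) + far * t_vals

mids = 0.5 * (t_vals[1:] + t_vals[:-1])
upper = jnp.concatenate([mids, t_vals[-1:]])
lower = jnp.concatenate([t_vals[:1], mids])
t_rand = random.uniform(random.PRNGKey(0), [num_samples + 1])
print(lower + (upper - lower) * t_rand)  # jittered fenceposts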

Positional encoding: PE / IPE

If self.disable_integration is True, plain PE is used instead of IPE.

      if self.disable_integration:
        samples = (samples[0], jnp.zeros_like(samples[1]))
      samples_enc = mip.integrated_pos_enc(
          samples,
          self.min_deg_point,
          self.max_deg_point,
      )

def integrated_pos_enc
Encode x with sinusoids scaled by 2^[min_deg:max_deg-1].

def integrated_pos_enc(x_coord, min_deg, max_deg, diag=True):
  """Encode `x` with sinusoids scaled by 2^[min_deg:max_deg-1].

  Args:
    x_coord: a tuple containing: x, jnp.ndarray, variables to be encoded. Should
      be in [-pi, pi]. x_cov, jnp.ndarray, covariance matrices for `x`.

    min_deg: int, the min degree of the encoding.

    max_deg: int, the max degree of the encoding.

    diag: bool, if true, expects input covariances to be diagonal (full
      otherwise).

  Returns:
    encoded: jnp.ndarray, encoded variables.

"""
  if diag:
    x, x_cov_diag = x_coord
    scales = jnp.array([2**i for i in range(min_deg, max_deg)])
    shape = list(x.shape[:-1]) + [-1]
    y = jnp.reshape(x[..., None, :] * scales[:, None], shape)
    y_var = jnp.reshape(x_cov_diag[..., None, :] * scales[:, None]**2, shape)
  else:
    x, x_cov = x_coord
    num_dims = x.shape[-1]
    basis = jnp.concatenate(
        [2**i * jnp.eye(num_dims) for i in range(min_deg, max_deg)], 1)
    y = math.matmul(x, basis)

    y_var = jnp.sum((math.matmul(x_cov, basis)) * basis, -2)

  return expected_sin(
      jnp.concatenate([y, y + 0.5 * jnp.pi], axis=-1),
      jnp.concatenate([y_var] * 2, axis=-1))[0]

def expected_sin

Estimate the mean and variance of sin(z), z ~ N(x, var).

def expected_sin(x, x_var):
  """Estimates mean and variance of sin(z), z ~ N(x, var)."""

  y = jnp.exp(-0.5 * x_var) * math.safe_sin(x)
  y_var = jnp.maximum(
      0, 0.5 * (1 - jnp.exp(-2 * x_var) * math.safe_cos(2 * x)) - y**2)
  return y, y_var
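This is the closed form for the moments of a sinusoid under a Gaussian, which is what makes IPE tractable: for $z \sim \mathcal{N}(\mu, \sigma^2)$,

$$\mathbb{E}[\sin z] = e^{-\sigma^2/2}\sin\mu, \qquad \operatorname{Var}[\sin z] = \tfrac{1}{2}\left(1 - e^{-2\sigma^2}\cos 2\mu\right) - \left(e^{-\sigma^2/2}\sin\mu\right)^2$$

which is exactly y and y_var above (clamped at 0 against round-off).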

Per-point property prediction: raw_rgb, raw_density

raw_rgb, raw_density = mlp(samples_enc, viewdirs_enc)
raw_rgb, raw_density = mlp(samples_enc)


      if self.use_viewdirs:
        viewdirs_enc = mip.pos_enc(
            rays.viewdirs,
            min_deg=0,
            max_deg=self.deg_view,
            append_identity=True,
        )
        raw_rgb, raw_density = mlp(samples_enc, viewdirs_enc)
      else:
        raw_rgb, raw_density = mlp(samples_enc)

def pos_enc
The positional encoding used by the original NeRF paper.

def pos_enc(x, min_deg, max_deg, append_identity=True):
  """The positional encoding used by the original NeRF paper."""
  scales = jnp.array([2**i for i in range(min_deg, max_deg)])
  xb = jnp.reshape((x[..., None, :] * scales[:, None]),
                   list(x.shape[:-1]) + [-1])
  four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
  if append_identity:
    return jnp.concatenate([x] + [four_feat], axis=-1)
  else:
    return four_feat
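Note the trick: since sin(x + π/2) = cos(x), concatenating xb and xb + 0.5*π before the single sin call produces both the sine and cosine features of the classic NeRF encoding. A quick shape check (a sketch reusing pos_enc above):

import jax.numpy as jnp

x = jnp.array([[0.1, 0.2, 0.3]])
enc = pos_enc(x, min_deg=0, max_deg=4)
# 3 identity dims + 3 dims * 4 frequencies * 2 (sin & cos) = 27 features
print(enc.shape)  # (1, 27)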

Volumetric rendering

Add noise to regularize the density predictions, if requested.


      if randomized and (self.density_noise > 0):
        key, rng = random.split(rng)
        raw_density += self.density_noise * random.normal(
            key, raw_density.shape, dtype=raw_density.dtype)

      rgb = self.rgb_activation(raw_rgb)
      rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding
      density = self.density_activation(raw_density + self.density_bias)
      comp_rgb, distance, acc, weights = mip.volumetric_rendering(
          rgb,
          density,
          t_vals,
          rays.directions,
          white_bkgd=white_bkgd,
      )
      ret.append((comp_rgb, distance, acc))

    return ret

ret

Two levels (coarse and fine), each with three values: (comp_rgb, distance, acc).

def volumetric_rendering

With $\sigma_i$ the density, $\mathbf{c}_i$ the color, and $\delta_i$ the distance between adjacent samples, the weights below implement the standard volume rendering quadrature:

$$\hat{C}(\mathbf{r}) = \sum_i T_i\left(1 - e^{-\sigma_i\delta_i}\right)\mathbf{c}_i, \qquad T_i = \exp\Bigl(-\sum_{j<i}\sigma_j\delta_j\Bigr)$$

def volumetric_rendering(rgb, density, t_vals, dirs, white_bkgd):
  """Volumetric Rendering Function.

  Args:
    rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
    density: jnp.ndarray(float32), density, [batch_size, num_samples, 1].

    t_vals: jnp.ndarray(float32), [batch_size, num_samples].

    dirs: jnp.ndarray(float32), [batch_size, 3].

    white_bkgd: bool.

  Returns:
    comp_rgb: jnp.ndarray(float32), [batch_size, 3].

    disp: jnp.ndarray(float32), [batch_size].

    acc: jnp.ndarray(float32), [batch_size].

    weights: jnp.ndarray(float32), [batch_size, num_samples]
"""
  t_mids = 0.5 * (t_vals[..., :-1] + t_vals[..., 1:])
  t_dists = t_vals[..., 1:] - t_vals[..., :-1]
  delta = t_dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)

  density_delta = density[..., 0] * delta

  alpha = 1 - jnp.exp(-density_delta)
  trans = jnp.exp(-jnp.concatenate([
      jnp.zeros_like(density_delta[..., :1]),
      jnp.cumsum(density_delta[..., :-1], axis=-1)
  ],
                                   axis=-1))
  weights = alpha * trans

  comp_rgb = (weights[..., None] * rgb).sum(axis=-2)
  acc = weights.sum(axis=-1)
  distance = (weights * t_mids).sum(axis=-1) / acc
  distance = jnp.clip(
      jnp.nan_to_num(distance, jnp.inf), t_vals[:, 0], t_vals[:, -1])
  if white_bkgd:
    comp_rgb = comp_rgb + (1. - acc[..., None])
  return comp_rgb, distance, acc, weights

class MLP

@gin.configurable
class MLP(nn.Module):
  """A simple MLP."""
  net_depth: int = 8
  net_width: int = 256
  net_depth_condition: int = 1
  net_width_condition: int = 128
  net_activation: Callable[..., Any] = nn.relu
  skip_layer: int = 4
  num_rgb_channels: int = 3
  num_density_channels: int = 1

def __call__

  @nn.compact
  def __call__(self, x, condition=None):
    """Evaluate the MLP.

    Args:
      x: jnp.ndarray(float32), [batch, num_samples, feature], points.

      condition: jnp.ndarray(float32), [batch, feature], if not None, this
        variable will be part of the input to the second part of the MLP
        concatenated with the output vector of the first part of the MLP. If
        None, only the first part of the MLP will be used with input x. In the
        original paper, this variable is the view direction.

    Returns:
      raw_rgb: jnp.ndarray(float32), with a shape of
           [batch, num_samples, num_rgb_channels].

      raw_density: jnp.ndarray(float32), with a shape of
           [batch, num_samples, num_density_channels].

"""
    feature_dim = x.shape[-1]
    num_samples = x.shape[1]
    x = x.reshape([-1, feature_dim])
    dense_layer = functools.partial(
        nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())
    inputs = x
    for i in range(self.net_depth):
      x = dense_layer(self.net_width)(x)
      x = self.net_activation(x)
      if i % self.skip_layer == 0 and i > 0:
        x = jnp.concatenate([x, inputs], axis=-1)
    raw_density = dense_layer(self.num_density_channels)(x).reshape(
        [-1, num_samples, self.num_density_channels])

    if condition is not None:

      bottleneck = dense_layer(self.net_width)(x)

      condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))

      condition = condition.reshape([-1, condition.shape[-1]])
      x = jnp.concatenate([bottleneck, condition], axis=-1)

      for i in range(self.net_depth_condition):
        x = dense_layer(self.net_width_condition)(x)
        x = self.net_activation(x)
    raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape(
        [-1, num_samples, self.num_rgb_channels])
    return raw_rgb, raw_density


jax

pmap

  train_pstep = jax.pmap(
      functools.partial(train_step, model, config),
      axis_name='batch',
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

Parallel map with support for collective operations

def pmap(
  fun: F,
  axis_name: Optional[AxisName] = None,
  *,
  in_axes=0,
  out_axes=0,
  static_broadcasted_argnums: Union[int, Iterable[int]] = (),
  devices: Optional[Sequence[xc.Device]] = None,
  backend: Optional[str] = None,
  axis_size: Optional[int] = None,
  donate_argnums: Union[int, Iterable[int]] = (),
  global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
) -> F:
  """Parallel map with support for collective operations.

  The purpose of :py:func:`pmap` is to express single-program multiple-data
  (SPMD) programs. Applying :py:func:`pmap` to a function will compile the
  function with XLA (similarly to :py:func:`jit`), then execute it in parallel
  on XLA devices, such as multiple GPUs or multiple TPU cores.

  Semantically it is comparable to :py:func:`vmap` because both transformations
  map a function over array axes, but where :py:func:`vmap` vectorizes
  functions by pushing the mapped axis down into primitive operations,
  :py:func:`pmap` instead replicates the function and executes each replica on
  its own XLA device in parallel.

  The mapped axis size must be less than or equal to the number of local XLA
  devices available, as returned by :py:func:`jax.local_device_count()` (unless
  ``devices`` is specified, see below). For nested :py:func:`pmap` calls, the
  product of the mapped axis sizes must be less than or equal to the number of
  XLA devices.

  .. note::
    :py:func:`pmap` compiles ``fun``, so while it can be combined with
    :py:func:`jit`, it's usually unnecessary.

  **Multi-process platforms:** On multi-process platforms such as TPU pods,
  :py:func:`pmap` is designed to be used in SPMD Python programs, where every
  process is running the same Python code such that all processes run the same
  pmapped function in the same order. Each process should still call the pmapped
  function with mapped axis size equal to the number of *local* devices (unless
  ``devices`` is specified, see below), and an array of the same leading axis
  size will be returned as usual. However, any collective operations in ``fun``
  will be computed over *all* participating devices, including those on other
  processes, via device-to-device communication.  Conceptually, this can be
  thought of as running a pmap over a single array sharded across processes,
  where each process "sees" only its local shard of the input and output. The
  SPMD model requires that the same multi-process pmaps must be run in the same
  order on all devices, but they can be interspersed with arbitrary operations
  running in a single process.

  Args:
    fun: Function to be mapped over argument axes. Its arguments and return
      value should be arrays, scalars, or (nested) standard Python containers
      (tuple/list/dict) thereof. Positional arguments indicated by
      ``static_broadcasted_argnums`` can be anything at all, provided they are
      hashable and have an equality operation defined.

    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.

    in_axes: A non-negative integer, None, or nested Python container thereof
      that specifies which axes of positional arguments to map over. Arguments
      passed as keywords are always mapped over their leading axis (i.e. axis
      index 0). See :py:func:`vmap` for details.

    out_axes: A non-negative integer, None, or nested Python container thereof
      indicating where the mapped axis should appear in the output. All outputs
      with a mapped axis must have a non-None ``out_axes`` specification
      (see :py:func:`vmap`).

    static_broadcasted_argnums: An int or collection of ints specifying which
      positional arguments to treat as static (compile-time constant).

      Operations that only depend on static arguments will be constant-folded.

      Calling the pmapped function with different values for these constants
      will trigger recompilation. If the pmapped function is called with fewer
      positional arguments than indicated by ``static_argnums`` then an error is
      raised. Each of the static arguments will be broadcasted to all devices.

      Arguments that are not arrays or containers thereof must be marked as
      static. Defaults to ().

    devices: This is an experimental feature and the API is likely to change.

      Optional, a sequence of Devices to map over. (Available devices can be
      retrieved via jax.devices()). Must be given identically for each process
      in multi-process settings (and will therefore include devices across
      processes). If specified, the size of the mapped axis must be equal to
      the number of devices in the sequence local to the given process. Nested
      :py:func:`pmap` s with ``devices`` specified in either the inner or outer
      :py:func:`pmap` are not yet supported.

    backend: This is an experimental feature and the API is likely to change.

      Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.

    axis_size: Optional; the size of the mapped axis.

    donate_argnums: Specify which arguments are "donated" to the computation.

      It is safe to donate arguments if you no longer need them once the
      computation has finished. In some cases XLA can make use of donated
      buffers to reduce the amount of memory needed to perform a computation,
      for example recycling one of your input buffers to store a result. You
      should not reuse buffers that you donate to a computation, JAX will raise
      an error if you try to.

    global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and
      the partitioned values span multiple processes. The global cross-process
      per-replica shape of each argument, i.e. does not include the leading
      pmapped dimension. Can be None for replicated arguments. This API is
      likely to change in the future.

  Returns:
    A parallelized version of ``fun`` with arguments that correspond to those of
    ``fun`` but with extra array axes at positions indicated by ``in_axes`` and
    with output that has an additional leading array axis (with the same size).

  For example, assuming 8 XLA devices are available, :py:func:`pmap` can be used
  as a map along a leading array axis:

  >>> import jax.numpy as jnp
  >>>
  >>> out = pmap(lambda x: x ** 2)(jnp.arange(8))  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [0, 1, 4, 9, 16, 25, 36, 49]

  When the leading dimension is smaller than the number of available devices JAX
  will simply run on a subset of devices:

  >>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
  >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
  >>> out = pmap(jnp.dot)(x, y)  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [[[    4.     9.]
    [   12.    29.]]
   [[  244.   345.]
    [  348.   493.]]
   [[ 1412.  1737.]
    [ 1740.  2141.]]]

  If your leading dimension is larger than the number of available devices you
  will get an error:

  >>> pmap(lambda x: x ** 2)(jnp.arange(9))  # doctest: +SKIP
  ValueError: ... requires 9 replicas, but only 8 XLA devices are available

  As with :py:func:`vmap`, using ``None`` in ``in_axes`` indicates that an
  argument doesn't have an extra axis and should be broadcasted, rather than
  mapped, across the replicas:

  >>> x, y = jnp.arange(2.), 4.

  >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  ([4., 5.], [8., 8.])

  Note that :py:func:`pmap` always returns values mapped over their leading axis,
  equivalent to using ``out_axes=0`` in :py:func:`vmap`.

  In addition to expressing pure maps, :py:func:`pmap` can also be used to express
  parallel single-program multiple-data (SPMD) programs that communicate via
  collective operations. For example:

  >>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
  >>> out = pmap(f, axis_name='i')(jnp.arange(4.))  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [ 0.          0.16666667  0.33333334  0.5       ]
  >>> print(out.sum())  # doctest: +SKIP
  1.0

  In this example, ``axis_name`` is a string, but it can be any Python object
  with ``__hash__`` and ``__eq__`` defined.

  The argument ``axis_name`` to :py:func:`pmap` names the mapped axis so that
  collective operations, like :func:`jax.lax.psum`, can refer to it. Axis names
  are important particularly in the case of nested :py:func:`pmap` functions,
  where collective operations can operate over distinct axes:

  >>> from functools import partial
  >>> import jax
  >>>
  >>> @partial(pmap, axis_name='rows')
  ... @partial(pmap, axis_name='cols')
  ... def normalize(x):
  ...   row_normed = x / jax.lax.psum(x, 'rows')
  ...   col_normed = x / jax.lax.psum(x, 'cols')
  ...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
  ...   return row_normed, col_normed, doubly_normed
  >>>
  >>> x = jnp.arange(8.).reshape((4, 2))
  >>> row_normed, col_normed, doubly_normed = normalize(x)  # doctest: +SKIP
  >>> print(row_normed.sum(0))  # doctest: +SKIP
  [ 1.  1.]
  >>> print(col_normed.sum(1))  # doctest: +SKIP
  [ 1.  1.  1.  1.]
  >>> print(doubly_normed.sum((0, 1)))  # doctest: +SKIP
  1.0

  On multi-process platforms, collective operations operate over all devices,
  including those on other processes. For example, assuming the following code
  runs on two processes with 4 XLA devices each:

  >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
  >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
  >>> out = pmap(f, axis_name='i')(data)  # doctest: +SKIP
  >>> print(out)  # doctest: +SKIP
  [28 29 30 31] # on process 0
  [32 33 34 35] # on process 1

  Each process passes in a different length-4 array, corresponding to its 4
  local devices, and the psum operates over all 8 values. Conceptually, the two
  length-4 arrays can be thought of as a sharded length-8 array (in this example
  equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped
  axis given name 'i'. The pmap call on each process then returns the
  corresponding length-4 output shard.

  The ``devices`` argument can be used to specify exactly which devices are used
  to run the parallel computation. For example, again assuming a single process
  with 8 devices, the following code defines two parallel computations, one
  which runs on the first six devices and one on the remaining two:

  >>> from functools import partial
  >>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
  ... def f1(x):
  ...   return x / jax.lax.psum(x, axis_name='i')
  >>>
  >>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
  ... def f2(x):
  ...   return jax.lax.psum(x ** 2, axis_name='i')
  >>>
  >>> print(f1(jnp.arange(6.)))  # doctest: +SKIP
  [0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
  >>> print(f2(jnp.array([2., 3.])))  # doctest: +SKIP
  [ 13.  13.]
"""

Original: https://blog.csdn.net/qq_43620967/article/details/124495350
Author: ysh9888
Title: mip-NeRF code debug


