media/libvpx/vp9/encoder/vp9_segmentation.c

changeset 0
6474c204b198
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/media/libvpx/vp9/encoder/vp9_segmentation.c	Wed Dec 31 06:09:35 2014 +0100
     1.3 @@ -0,0 +1,289 @@
     1.4 +/*
     1.5 + *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
     1.6 + *
     1.7 + *  Use of this source code is governed by a BSD-style license
     1.8 + *  that can be found in the LICENSE file in the root of the source
     1.9 + *  tree. An additional intellectual property rights grant can be found
    1.10 + *  in the file PATENTS.  All contributing project authors may
    1.11 + *  be found in the AUTHORS file in the root of the source tree.
    1.12 + */
    1.13 +
    1.14 +
    1.15 +#include <limits.h>
    1.16 +#include "vpx_mem/vpx_mem.h"
    1.17 +#include "vp9/encoder/vp9_segmentation.h"
    1.18 +#include "vp9/common/vp9_pred_common.h"
    1.19 +#include "vp9/common/vp9_tile_common.h"
    1.20 +
    1.21 +void vp9_enable_segmentation(VP9_PTR ptr) {
    1.22 +  VP9_COMP *cpi = (VP9_COMP *)ptr;
    1.23 +  struct segmentation *const seg =  &cpi->common.seg;
    1.24 +
    1.25 +  seg->enabled = 1;
    1.26 +  seg->update_map = 1;
    1.27 +  seg->update_data = 1;
    1.28 +}
    1.29 +
    1.30 +void vp9_disable_segmentation(VP9_PTR ptr) {
    1.31 +  VP9_COMP *cpi = (VP9_COMP *)ptr;
    1.32 +  struct segmentation *const seg =  &cpi->common.seg;
    1.33 +  seg->enabled = 0;
    1.34 +}
    1.35 +
    1.36 +void vp9_set_segmentation_map(VP9_PTR ptr,
    1.37 +                              unsigned char *segmentation_map) {
    1.38 +  VP9_COMP *cpi = (VP9_COMP *)ptr;
    1.39 +  struct segmentation *const seg = &cpi->common.seg;
    1.40 +
    1.41 +  // Copy in the new segmentation map
    1.42 +  vpx_memcpy(cpi->segmentation_map, segmentation_map,
    1.43 +             (cpi->common.mi_rows * cpi->common.mi_cols));
    1.44 +
    1.45 +  // Signal that the map should be updated.
    1.46 +  seg->update_map = 1;
    1.47 +  seg->update_data = 1;
    1.48 +}
    1.49 +
    1.50 +void vp9_set_segment_data(VP9_PTR ptr,
    1.51 +                          signed char *feature_data,
    1.52 +                          unsigned char abs_delta) {
    1.53 +  VP9_COMP *cpi = (VP9_COMP *)ptr;
    1.54 +  struct segmentation *const seg = &cpi->common.seg;
    1.55 +
    1.56 +  seg->abs_delta = abs_delta;
    1.57 +
    1.58 +  vpx_memcpy(seg->feature_data, feature_data, sizeof(seg->feature_data));
    1.59 +
    1.60 +  // TBD ?? Set the feature mask
    1.61 +  // vpx_memcpy(cpi->mb.e_mbd.segment_feature_mask, 0,
    1.62 +  //            sizeof(cpi->mb.e_mbd.segment_feature_mask));
    1.63 +}
    1.64 +
    1.65 +// Based on set of segment counts calculate a probability tree
    1.66 +static void calc_segtree_probs(int *segcounts, vp9_prob *segment_tree_probs) {
    1.67 +  // Work out probabilities of each segment
    1.68 +  const int c01 = segcounts[0] + segcounts[1];
    1.69 +  const int c23 = segcounts[2] + segcounts[3];
    1.70 +  const int c45 = segcounts[4] + segcounts[5];
    1.71 +  const int c67 = segcounts[6] + segcounts[7];
    1.72 +
    1.73 +  segment_tree_probs[0] = get_binary_prob(c01 + c23, c45 + c67);
    1.74 +  segment_tree_probs[1] = get_binary_prob(c01, c23);
    1.75 +  segment_tree_probs[2] = get_binary_prob(c45, c67);
    1.76 +  segment_tree_probs[3] = get_binary_prob(segcounts[0], segcounts[1]);
    1.77 +  segment_tree_probs[4] = get_binary_prob(segcounts[2], segcounts[3]);
    1.78 +  segment_tree_probs[5] = get_binary_prob(segcounts[4], segcounts[5]);
    1.79 +  segment_tree_probs[6] = get_binary_prob(segcounts[6], segcounts[7]);
    1.80 +}
    1.81 +
    1.82 +// Based on set of segment counts and probabilities calculate a cost estimate
    1.83 +static int cost_segmap(int *segcounts, vp9_prob *probs) {
    1.84 +  const int c01 = segcounts[0] + segcounts[1];
    1.85 +  const int c23 = segcounts[2] + segcounts[3];
    1.86 +  const int c45 = segcounts[4] + segcounts[5];
    1.87 +  const int c67 = segcounts[6] + segcounts[7];
    1.88 +  const int c0123 = c01 + c23;
    1.89 +  const int c4567 = c45 + c67;
    1.90 +
    1.91 +  // Cost the top node of the tree
    1.92 +  int cost = c0123 * vp9_cost_zero(probs[0]) +
    1.93 +             c4567 * vp9_cost_one(probs[0]);
    1.94 +
    1.95 +  // Cost subsequent levels
    1.96 +  if (c0123 > 0) {
    1.97 +    cost += c01 * vp9_cost_zero(probs[1]) +
    1.98 +            c23 * vp9_cost_one(probs[1]);
    1.99 +
   1.100 +    if (c01 > 0)
   1.101 +      cost += segcounts[0] * vp9_cost_zero(probs[3]) +
   1.102 +              segcounts[1] * vp9_cost_one(probs[3]);
   1.103 +    if (c23 > 0)
   1.104 +      cost += segcounts[2] * vp9_cost_zero(probs[4]) +
   1.105 +              segcounts[3] * vp9_cost_one(probs[4]);
   1.106 +  }
   1.107 +
   1.108 +  if (c4567 > 0) {
   1.109 +    cost += c45 * vp9_cost_zero(probs[2]) +
   1.110 +            c67 * vp9_cost_one(probs[2]);
   1.111 +
   1.112 +    if (c45 > 0)
   1.113 +      cost += segcounts[4] * vp9_cost_zero(probs[5]) +
   1.114 +              segcounts[5] * vp9_cost_one(probs[5]);
   1.115 +    if (c67 > 0)
   1.116 +      cost += segcounts[6] * vp9_cost_zero(probs[6]) +
   1.117 +              segcounts[7] * vp9_cost_one(probs[6]);
   1.118 +  }
   1.119 +
   1.120 +  return cost;
   1.121 +}
   1.122 +
   1.123 +static void count_segs(VP9_COMP *cpi, const TileInfo *const tile,
   1.124 +                       MODE_INFO **mi_8x8,
   1.125 +                       int *no_pred_segcounts,
   1.126 +                       int (*temporal_predictor_count)[2],
   1.127 +                       int *t_unpred_seg_counts,
   1.128 +                       int bw, int bh, int mi_row, int mi_col) {
   1.129 +  VP9_COMMON *const cm = &cpi->common;
   1.130 +  MACROBLOCKD *const xd = &cpi->mb.e_mbd;
   1.131 +  int segment_id;
   1.132 +
   1.133 +  if (mi_row >= cm->mi_rows || mi_col >= cm->mi_cols)
   1.134 +    return;
   1.135 +
   1.136 +  xd->mi_8x8 = mi_8x8;
   1.137 +  segment_id = xd->mi_8x8[0]->mbmi.segment_id;
   1.138 +
   1.139 +  set_mi_row_col(xd, tile, mi_row, bh, mi_col, bw, cm->mi_rows, cm->mi_cols);
   1.140 +
   1.141 +  // Count the number of hits on each segment with no prediction
   1.142 +  no_pred_segcounts[segment_id]++;
   1.143 +
   1.144 +  // Temporal prediction not allowed on key frames
   1.145 +  if (cm->frame_type != KEY_FRAME) {
   1.146 +    const BLOCK_SIZE bsize = mi_8x8[0]->mbmi.sb_type;
   1.147 +    // Test to see if the segment id matches the predicted value.
   1.148 +    const int pred_segment_id = vp9_get_segment_id(cm, cm->last_frame_seg_map,
   1.149 +                                                   bsize, mi_row, mi_col);
   1.150 +    const int pred_flag = pred_segment_id == segment_id;
   1.151 +    const int pred_context = vp9_get_pred_context_seg_id(xd);
   1.152 +
   1.153 +    // Store the prediction status for this mb and update counts
   1.154 +    // as appropriate
   1.155 +    vp9_set_pred_flag_seg_id(xd, pred_flag);
   1.156 +    temporal_predictor_count[pred_context][pred_flag]++;
   1.157 +
   1.158 +    if (!pred_flag)
   1.159 +      // Update the "unpredicted" segment count
   1.160 +      t_unpred_seg_counts[segment_id]++;
   1.161 +  }
   1.162 +}
   1.163 +
   1.164 +static void count_segs_sb(VP9_COMP *cpi, const TileInfo *const tile,
   1.165 +                          MODE_INFO **mi_8x8,
   1.166 +                          int *no_pred_segcounts,
   1.167 +                          int (*temporal_predictor_count)[2],
   1.168 +                          int *t_unpred_seg_counts,
   1.169 +                          int mi_row, int mi_col,
   1.170 +                          BLOCK_SIZE bsize) {
   1.171 +  const VP9_COMMON *const cm = &cpi->common;
   1.172 +  const int mis = cm->mode_info_stride;
   1.173 +  int bw, bh;
   1.174 +  const int bs = num_8x8_blocks_wide_lookup[bsize], hbs = bs / 2;
   1.175 +
   1.176 +  if (mi_row >= cm->mi_rows || mi_col >= cm->mi_cols)
   1.177 +    return;
   1.178 +
   1.179 +  bw = num_8x8_blocks_wide_lookup[mi_8x8[0]->mbmi.sb_type];
   1.180 +  bh = num_8x8_blocks_high_lookup[mi_8x8[0]->mbmi.sb_type];
   1.181 +
   1.182 +  if (bw == bs && bh == bs) {
   1.183 +    count_segs(cpi, tile, mi_8x8, no_pred_segcounts, temporal_predictor_count,
   1.184 +               t_unpred_seg_counts, bs, bs, mi_row, mi_col);
   1.185 +  } else if (bw == bs && bh < bs) {
   1.186 +    count_segs(cpi, tile, mi_8x8, no_pred_segcounts, temporal_predictor_count,
   1.187 +               t_unpred_seg_counts, bs, hbs, mi_row, mi_col);
   1.188 +    count_segs(cpi, tile, mi_8x8 + hbs * mis, no_pred_segcounts,
   1.189 +               temporal_predictor_count, t_unpred_seg_counts, bs, hbs,
   1.190 +               mi_row + hbs, mi_col);
   1.191 +  } else if (bw < bs && bh == bs) {
   1.192 +    count_segs(cpi, tile, mi_8x8, no_pred_segcounts, temporal_predictor_count,
   1.193 +               t_unpred_seg_counts, hbs, bs, mi_row, mi_col);
   1.194 +    count_segs(cpi, tile, mi_8x8 + hbs,
   1.195 +               no_pred_segcounts, temporal_predictor_count, t_unpred_seg_counts,
   1.196 +               hbs, bs, mi_row, mi_col + hbs);
   1.197 +  } else {
   1.198 +    const BLOCK_SIZE subsize = subsize_lookup[PARTITION_SPLIT][bsize];
   1.199 +    int n;
   1.200 +
   1.201 +    assert(bw < bs && bh < bs);
   1.202 +
   1.203 +    for (n = 0; n < 4; n++) {
   1.204 +      const int mi_dc = hbs * (n & 1);
   1.205 +      const int mi_dr = hbs * (n >> 1);
   1.206 +
   1.207 +      count_segs_sb(cpi, tile, &mi_8x8[mi_dr * mis + mi_dc],
   1.208 +                    no_pred_segcounts, temporal_predictor_count,
   1.209 +                    t_unpred_seg_counts,
   1.210 +                    mi_row + mi_dr, mi_col + mi_dc, subsize);
   1.211 +    }
   1.212 +  }
   1.213 +}
   1.214 +
   1.215 +void vp9_choose_segmap_coding_method(VP9_COMP *cpi) {
   1.216 +  VP9_COMMON *const cm = &cpi->common;
   1.217 +  struct segmentation *seg = &cm->seg;
   1.218 +
   1.219 +  int no_pred_cost;
   1.220 +  int t_pred_cost = INT_MAX;
   1.221 +
   1.222 +  int i, tile_col, mi_row, mi_col;
   1.223 +
   1.224 +  int temporal_predictor_count[PREDICTION_PROBS][2] = { { 0 } };
   1.225 +  int no_pred_segcounts[MAX_SEGMENTS] = { 0 };
   1.226 +  int t_unpred_seg_counts[MAX_SEGMENTS] = { 0 };
   1.227 +
   1.228 +  vp9_prob no_pred_tree[SEG_TREE_PROBS];
   1.229 +  vp9_prob t_pred_tree[SEG_TREE_PROBS];
   1.230 +  vp9_prob t_nopred_prob[PREDICTION_PROBS];
   1.231 +
   1.232 +  const int mis = cm->mode_info_stride;
   1.233 +  MODE_INFO **mi_ptr, **mi;
   1.234 +
   1.235 +  // Set default state for the segment tree probabilities and the
   1.236 +  // temporal coding probabilities
   1.237 +  vpx_memset(seg->tree_probs, 255, sizeof(seg->tree_probs));
   1.238 +  vpx_memset(seg->pred_probs, 255, sizeof(seg->pred_probs));
   1.239 +
   1.240 +  // First of all generate stats regarding how well the last segment map
   1.241 +  // predicts this one
   1.242 +  for (tile_col = 0; tile_col < 1 << cm->log2_tile_cols; tile_col++) {
   1.243 +    TileInfo tile;
   1.244 +
   1.245 +    vp9_tile_init(&tile, cm, 0, tile_col);
   1.246 +    mi_ptr = cm->mi_grid_visible + tile.mi_col_start;
   1.247 +    for (mi_row = 0; mi_row < cm->mi_rows;
   1.248 +         mi_row += 8, mi_ptr += 8 * mis) {
   1.249 +      mi = mi_ptr;
   1.250 +      for (mi_col = tile.mi_col_start; mi_col < tile.mi_col_end;
   1.251 +           mi_col += 8, mi += 8)
   1.252 +        count_segs_sb(cpi, &tile, mi, no_pred_segcounts,
   1.253 +                      temporal_predictor_count, t_unpred_seg_counts,
   1.254 +                      mi_row, mi_col, BLOCK_64X64);
   1.255 +    }
   1.256 +  }
   1.257 +
   1.258 +  // Work out probability tree for coding segments without prediction
   1.259 +  // and the cost.
   1.260 +  calc_segtree_probs(no_pred_segcounts, no_pred_tree);
   1.261 +  no_pred_cost = cost_segmap(no_pred_segcounts, no_pred_tree);
   1.262 +
   1.263 +  // Key frames cannot use temporal prediction
   1.264 +  if (!frame_is_intra_only(cm)) {
   1.265 +    // Work out probability tree for coding those segments not
   1.266 +    // predicted using the temporal method and the cost.
   1.267 +    calc_segtree_probs(t_unpred_seg_counts, t_pred_tree);
   1.268 +    t_pred_cost = cost_segmap(t_unpred_seg_counts, t_pred_tree);
   1.269 +
   1.270 +    // Add in the cost of the signaling for each prediction context.
   1.271 +    for (i = 0; i < PREDICTION_PROBS; i++) {
   1.272 +      const int count0 = temporal_predictor_count[i][0];
   1.273 +      const int count1 = temporal_predictor_count[i][1];
   1.274 +
   1.275 +      t_nopred_prob[i] = get_binary_prob(count0, count1);
   1.276 +
   1.277 +      // Add in the predictor signaling cost
   1.278 +      t_pred_cost += count0 * vp9_cost_zero(t_nopred_prob[i]) +
   1.279 +                     count1 * vp9_cost_one(t_nopred_prob[i]);
   1.280 +    }
   1.281 +  }
   1.282 +
   1.283 +  // Now choose which coding method to use.
   1.284 +  if (t_pred_cost < no_pred_cost) {
   1.285 +    seg->temporal_update = 1;
   1.286 +    vpx_memcpy(seg->tree_probs, t_pred_tree, sizeof(t_pred_tree));
   1.287 +    vpx_memcpy(seg->pred_probs, t_nopred_prob, sizeof(t_nopred_prob));
   1.288 +  } else {
   1.289 +    seg->temporal_update = 0;
   1.290 +    vpx_memcpy(seg->tree_probs, no_pred_tree, sizeof(no_pred_tree));
   1.291 +  }
   1.292 +}

mercurial