/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

///@defgroup embedding-cuda Embedding CUDA Operators

///@ingroup embedding-cuda
/// CUDA implementation of the `bounds_check_indices` operator.
///
/// Only the declaration lives in this file; the definition is in another
/// translation unit (NOTE(review): presumably the .cu file that launches the
/// kernel — confirm). This file binds it, via DISPATCH_TO_CUDA, to the
/// `bounds_check_indices` op in both the `fb` and `fbgemm` library namespaces.
///
/// The function returns nothing. Several tensors are taken by non-const
/// reference, so the implementation may modify some of them in place.
/// NOTE(review): confirm which ones (likely `indices` and/or `warning`)
/// against the definition.
///
/// @param rows_per_table    presumably the number of rows of each embedding
///                          table, i.e. the per-table upper bound for indices
///                          (TODO confirm).
/// @param indices           the indices to be bounds-checked.
/// @param offsets           presumably the bag/segment offsets into `indices`
///                          (TODO confirm).
/// @param bounds_check_mode integer selecting how out-of-bounds entries are
///                          handled; presumably an encoded enum value such as
///                          fatal/warning/ignore (TODO confirm).
/// @param warning           presumably a tensor written to signal that
///                          out-of-bounds entries were found (TODO confirm).
/// @param weights           optional tensor, presumably per-sample weights
///                          aligned with `indices` (TODO confirm).
/// @param B_offsets         optional tensor, presumably per-table batch offsets
///                          for variable batch sizes (TODO confirm).
/// @param max_B             presumably the maximum batch size across tables,
///                          used together with `B_offsets` (TODO confirm).
void bounds_check_indices_cuda(
    Tensor& rows_per_table,
    Tensor& indices,
    Tensor& offsets,
    int64_t bounds_check_mode,
    Tensor& warning,
    const c10::optional<Tensor>& weights,
    const c10::optional<Tensor>& B_offsets,
    const int64_t max_B);

// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
  DISPATCH_TO_CUDA("bounds_check_indices", bounds_check_indices_cuda);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  DISPATCH_TO_CUDA("bounds_check_indices", bounds_check_indices_cuda);
}
