Open3D (C++ API)  0.17.0
Loading...
Searching...
No Matches
BallQueryOpKernel.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
10#include "../TensorFlowHelper.h"
11#include "tensorflow/core/framework/op.h"
12#include "tensorflow/core/framework/op_kernel.h"
13#include "tensorflow/core/lib/core/errors.h"
14
15class BallQueryOpKernel : public tensorflow::OpKernel {
16public:
17 explicit BallQueryOpKernel(tensorflow::OpKernelConstruction* construction)
18 : OpKernel(construction) {
19 using namespace tensorflow;
20
21 OP_REQUIRES_OK(construction,
22 construction->GetAttr("nsample", &nsample));
23 OP_REQUIRES_OK(construction, construction->GetAttr("radius", &radius));
24 OP_REQUIRES(
25 construction, nsample > 0,
26 errors::InvalidArgument("BallQuery expects positive nsample"));
27 }
28
29 void Compute(tensorflow::OpKernelContext* context) override {
30 using namespace tensorflow;
31
32 const Tensor& inp_tensor = context->input(0);
33 OP_REQUIRES(
34 context,
35 inp_tensor.dims() == 3 && inp_tensor.shape().dim_size(2) == 3,
36 errors::InvalidArgument("BallQuery expects "
37 "(batch_size,num_points,3) inp shape"));
38 int batch_size = inp_tensor.shape().dim_size(0);
39 int pts_size = inp_tensor.shape().dim_size(1);
40 auto inp_flat = inp_tensor.flat<float>();
41 const float* inp = &(inp_flat(0));
42
43 const Tensor& center_tensor = context->input(1);
44 OP_REQUIRES(context,
45 center_tensor.dims() == 3 &&
46 center_tensor.shape().dim_size(2) == 3,
47 errors::InvalidArgument(
48 "BallQuery expects "
49 "(batch_size,num_points,3) center shape"));
50 int ball_size = center_tensor.shape().dim_size(1);
51 auto center_flat = center_tensor.flat<float>();
52 const float* center = &(center_flat(0));
53
54 Tensor* out_tensor;
55 OP_REQUIRES_OK(context,
56 context->allocate_output(
57 0, TensorShape{batch_size, ball_size, nsample},
58 &out_tensor));
59 auto out_flat = out_tensor->flat<int>();
60 int* out = &(out_flat(0));
61
62 Kernel(context, batch_size, pts_size, ball_size, radius, nsample,
63 center, inp, out);
64 }
65
66 virtual void Kernel(tensorflow::OpKernelContext* context,
67 int b,
68 int n,
69 int m,
70 float radius,
71 int nsample,
72 const float* new_xyz,
73 const float* xyz,
74 int* idx) = 0;
75
76protected:
78 float radius;
79};
ImGuiContext * context
Definition Window.cpp:76
BallQueryOpKernel
Definition BallQueryOpKernel.h:15
int nsample
Definition BallQueryOpKernel.h:77
BallQueryOpKernel(tensorflow::OpKernelConstruction *construction)
Definition BallQueryOpKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition BallQueryOpKernel.h:29
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int n, int m, float radius, int nsample, const float *new_xyz, const float *xyz, int *idx)=0
float radius
Definition BallQueryOpKernel.h:78