Skip to content

Commit 9f54db6

Browse files
authored
feat: translate and rotate domains to clip the main frame (#130)
* Using PV Transform filter * change for full vtkTransform inheritance * using vtkLandmarkTransform and vtkOBBTree * some doc and accessible tests
1 parent 8081656 commit 9f54db6

File tree

3 files changed

+460
-0
lines changed

3 files changed

+460
-0
lines changed

docs/geos_mesh_docs/processing.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ geos.mesh.processing.SplitMesh filter
3333
--------------------------------------
3434

3535
.. automodule:: geos.mesh.processing.SplitMesh
36+
:members:
37+
:undoc-members:
38+
:show-inheritance:
39+
40+
geos.mesh.processing.ClipToMainFrame filter
41+
--------------------------------------------
42+
43+
.. automodule:: geos.mesh.processing.ClipToMainFrame
3644
:members:
3745
:undoc-members:
3846
:show-inheritance:
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
# SPDX-License-Identifier: Apache 2.0
2+
# SPDX-FileCopyrightText: Copyright 2023-2025 TotalEnergies
3+
# SPDX-FileContributor: Jacques Franc
4+
5+
from vtkmodules.numpy_interface import dataset_adapter as dsa
6+
from vtkmodules.vtkCommonCore import vtkPoints
7+
from vtkmodules.vtkCommonMath import vtkMatrix4x4
8+
from vtkmodules.vtkFiltersGeneral import vtkOBBTree
9+
from vtkmodules.vtkFiltersGeometry import vtkDataSetSurfaceFilter
10+
from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkMultiBlockDataSet, vtkDataObjectTreeIterator, vtkPolyData
11+
from vtkmodules.vtkCommonTransforms import vtkLandmarkTransform
12+
from vtkmodules.vtkFiltersGeneral import vtkTransformFilter
13+
14+
from geos.utils.Logger import logging, Logger, getLogger
15+
16+
from geos.mesh.utils.genericHelpers import getMultiBlockBounds
17+
18+
import numpy as np
19+
import numpy.typing as npt
20+
21+
from typing import Tuple
22+
23+
__doc__ = """
24+
Module to clip a mesh to the main frame using rigid body transformation.
25+
26+
Methods include:
27+
- ClipToMainFrameElement class to compute the transformation
28+
- ClipToMainFrame class to apply the transformation to a mesh
29+
30+
To use it:
31+
32+
.. code-block:: python
33+
34+
from geos.mesh.processing.ClipToMainFrame import ClipToMainFrame
35+
36+
# Filter inputs.
37+
multiBlockDataSet: vtkMultiBlockDataSet
38+
# Optional Inputs
39+
speHandler : bool
40+
41+
# Instantiate the filter.
42+
filter: ClipToMainFrame = ClipToMainFrame()
43+
filter.SetInputData( multiBlockDataSet )
44+
45+
# Set the handler of yours (only if speHandler is True).
46+
yourHandler: logging.Handler
47+
filter.setLoggerHandler( yourHandler )
48+
49+
# Do calculations.
50+
filter.ComputeTransform()
51+
filter.Update()
52+
output: vtkMultiBlockDataSet = filter.GetOutput()
53+
54+
"""
55+
56+
57+
class ClipToMainFrameElement( vtkLandmarkTransform ):
58+
59+
sourcePts: vtkPoints
60+
targetPts: vtkPoints
61+
mesh: vtkUnstructuredGrid
62+
63+
def __init__( self, mesh: vtkUnstructuredGrid ) -> None:
64+
"""Clip mesh to main frame.
65+
66+
Args:
67+
mesh (vtkUnstructuredGrid): Mesh to clip.
68+
"""
69+
super().__init__()
70+
self.mesh = mesh
71+
72+
def Update( self ) -> None:
73+
"""Update the transformation."""
74+
self.sourcePts, self.targetPts = self.__getFramePoints( self.__getOBBTree( self.mesh ) )
75+
self.SetSourceLandmarks( self.sourcePts )
76+
self.SetTargetLandmarks( self.targetPts )
77+
self.SetModeToRigidBody()
78+
super().Update()
79+
80+
def __str__( self ) -> str:
81+
"""String representation of the transformation."""
82+
return super().__str__() + f"\nSource points: {self.sourcePts}" \
83+
+ f"\nTarget points: {self.targetPts}" \
84+
+ f"\nAngle-Axis: {self.__getAngleAxis()}" \
85+
+ f"\nTranslation: {self.__getTranslation()}"
86+
87+
def __getAngleAxis( self ) -> tuple[ float, npt.NDArray[ np.double ] ]:
88+
"""Get the angle and axis of the rotation.
89+
90+
tuple[float, npt.NDArray[np.double]]: Angle in degrees and axis of rotation.
91+
"""
92+
matrix: vtkMatrix4x4 = self.GetMatrix()
93+
angle: np.double = np.arccos(
94+
( matrix.GetElement( 0, 0 ) + matrix.GetElement( 1, 1 ) + matrix.GetElement( 2, 2 ) - 1 ) / 2 )
95+
if angle == 0:
96+
return 0.0, np.array( [ 1.0, 0.0, 0.0 ] )
97+
rx: float = matrix.GetElement( 2, 1 ) - matrix.GetElement( 1, 2 )
98+
ry: float = matrix.GetElement( 0, 2 ) - matrix.GetElement( 2, 0 )
99+
rz: float = matrix.GetElement( 1, 0 ) - matrix.GetElement( 0, 1 )
100+
r = np.array( [ rx, ry, rz ] )
101+
r /= np.linalg.norm( r )
102+
return np.degrees( angle ), r
103+
104+
def __getTranslation( self ) -> npt.NDArray[ np.double ]:
105+
"""Get the translation vector.
106+
107+
Returns:
108+
npt.NDArray[ np.double ]: The translation vector.
109+
"""
110+
matrix: vtkMatrix4x4 = self.GetMatrix()
111+
return np.array( [ matrix.GetElement( 0, 3 ), matrix.GetElement( 1, 3 ), matrix.GetElement( 2, 3 ) ] )
112+
113+
def __getOBBTree( self, mesh: vtkUnstructuredGrid ) -> vtkPoints:
114+
"""Get the OBB tree of the mesh.
115+
116+
Args:
117+
mesh (vtkUnstructuredGrid): Mesh to get the OBB tree from.
118+
119+
Returns:
120+
vtkPoints: Points from the 0-level OBB tree of the mesh. Fallback on Axis Aligned Bounding Box
121+
"""
122+
OBBTree = vtkOBBTree()
123+
surfFilter = vtkDataSetSurfaceFilter()
124+
surfFilter.SetInputData( mesh )
125+
surfFilter.Update()
126+
OBBTree.SetDataSet( surfFilter.GetOutput() )
127+
OBBTree.BuildLocator()
128+
pdata = vtkPolyData()
129+
OBBTree.GenerateRepresentation( 0, pdata )
130+
# at level 0 this should return 8 corners of the bounding box or fallback on AABB
131+
if pdata.GetNumberOfPoints() < 3:
132+
return self.__allpoints( mesh.GetBounds() )
133+
134+
return pdata.GetPoints()
135+
136+
def __allpoints( self, bounds: tuple[ float, float, float, float, float, float ] ) -> vtkPoints:
137+
"""Get the 8 corners of the bounding box.
138+
139+
Args:
140+
bounds (tuple[float, float, float, float, float, float]): Bounding box.
141+
142+
Returns:
143+
vtkPoints: 8 corners of the bounding box.
144+
"""
145+
pts = vtkPoints()
146+
pts.SetNumberOfPoints( 8 )
147+
pts.SetPoint( 0, [ bounds[ 0 ], bounds[ 2 ], bounds[ 4 ] ] )
148+
pts.SetPoint( 1, [ bounds[ 1 ], bounds[ 2 ], bounds[ 4 ] ] )
149+
pts.SetPoint( 2, [ bounds[ 1 ], bounds[ 3 ], bounds[ 4 ] ] )
150+
pts.SetPoint( 3, [ bounds[ 0 ], bounds[ 3 ], bounds[ 4 ] ] )
151+
pts.SetPoint( 4, [ bounds[ 0 ], bounds[ 2 ], bounds[ 5 ] ] )
152+
pts.SetPoint( 5, [ bounds[ 1 ], bounds[ 2 ], bounds[ 5 ] ] )
153+
pts.SetPoint( 6, [ bounds[ 1 ], bounds[ 3 ], bounds[ 5 ] ] )
154+
pts.SetPoint( 7, [ bounds[ 0 ], bounds[ 3 ], bounds[ 5 ] ] )
155+
return pts
156+
157+
def __getFramePoints( self, vpts: vtkPoints ) -> tuple[ vtkPoints, vtkPoints ]:
158+
"""Get the source and target points for the transformation.
159+
160+
Args:
161+
vpts (vtkPoints): Points to transform.
162+
163+
Returns:
164+
tuple[vtkPoints, vtkPoints]: Source and target points for the transformation.
165+
"""
166+
pts: npt.NDArray[ np.double ] = dsa.numpy_support.vtk_to_numpy( vpts.GetData() )
167+
#translate pts so they always lie on the -z,-y,-x quadrant
168+
off: npt.NDArray[ np.double ] = np.asarray( [
169+
-2 * np.amax( np.abs( pts[ :, 0 ] ) ), -2 * np.amax( np.abs( pts[ :, 1 ] ) ),
170+
-2 * np.amax( np.abs( pts[ :, 2 ] ) )
171+
] )
172+
pts += off
173+
further_ix: np.int_ = np.argmax( np.linalg.norm(
174+
pts, axis=1 ) ) # by default take the min point furthest from origin
175+
org: npt.NDArray = pts[ further_ix, : ]
176+
177+
# find 3 orthogonal vectors
178+
# we assume points are on a box
179+
dist_indexes: npt.NDArray[ np.int_ ] = np.argsort( np.linalg.norm( pts - org, axis=1 ) )
180+
# find u,v,w
181+
v1: npt.NDArray[ np.double ] = pts[ dist_indexes[ 1 ], : ] - org
182+
v2: npt.NDArray[ np.double ] = pts[ dist_indexes[ 2 ], : ] - org
183+
v1 /= np.linalg.norm( v1 )
184+
v2 /= np.linalg.norm( v2 )
185+
if np.abs( v1[ 0 ] ) > np.abs( v2[ 0 ] ):
186+
v1, v2 = v2, v1
187+
188+
# ensure orthogonality
189+
v2 -= np.dot( v2, v1 ) * v1
190+
v2 /= np.linalg.norm( v2 )
191+
v3: npt.NDArray[ np.double ] = np.cross( v1, v2 )
192+
v3 /= np.linalg.norm( v3 )
193+
194+
#reorder axis if v3 points downward
195+
if v3[ 2 ] < 0:
196+
v3 = -v3
197+
v1, v2 = v2, v1
198+
199+
sourcePts = vtkPoints()
200+
sourcePts.SetNumberOfPoints( 4 )
201+
sourcePts.SetPoint( 0, list( org - off ) )
202+
sourcePts.SetPoint( 1, list( v1 + org - off ) )
203+
sourcePts.SetPoint( 2, list( v2 + org - off ) )
204+
sourcePts.SetPoint( 3, list( v3 + org - off ) )
205+
206+
targetPts = vtkPoints()
207+
targetPts.SetNumberOfPoints( 4 )
208+
targetPts.SetPoint( 0, [ 0., 0., 0. ] )
209+
targetPts.SetPoint( 1, [ 1., 0., 0. ] )
210+
targetPts.SetPoint( 2, [ 0., 1., 0. ] )
211+
targetPts.SetPoint( 3, [ 0., 0., 1. ] )
212+
213+
return ( sourcePts, targetPts )
214+
215+
216+
loggerTitle: str = "Clip mesh to main frame."
217+
218+
219+
class ClipToMainFrame( vtkTransformFilter ):
220+
"""Filter to clip a mesh to the main frame using ClipToMainFrame class."""
221+
222+
def __init__( self, speHandler: bool = False, **properties: str ) -> None:
223+
"""Initialize the ClipToMainFrame Filter with optional speHandler args and forwarding properties to main class.
224+
225+
Args:
226+
speHandler (bool, optional): True to use a specific handler, False to use the internal handler.
227+
Defaults to False.
228+
properties (kwargs): kwargs forwarded to vtkTransformFilter.
229+
"""
230+
super().__init__( **properties )
231+
# Logger.
232+
self.logger: Logger
233+
if not speHandler:
234+
self.logger = getLogger( loggerTitle, True )
235+
else:
236+
self.logger = logging.getLogger( loggerTitle )
237+
self.logger.setLevel( logging.INFO )
238+
239+
def ComputeTransform( self ) -> None:
240+
"""Update the transformation."""
241+
# dispatch to ClipToMainFrame depending on input type
242+
if isinstance( self.GetInput(), vtkMultiBlockDataSet ):
243+
#locate reference point
244+
try:
245+
idBlock = self.__locate_reference_point( self.GetInput() )
246+
except IndexError:
247+
self.logger.error( "Reference point is not in the domain" )
248+
249+
clip = ClipToMainFrameElement( self.GetInput().GetDataSet( idBlock ) )
250+
else:
251+
clip = ClipToMainFrameElement( self.GetInput() )
252+
253+
clip.Update()
254+
self.SetTransform( clip )
255+
256+
def SetLoggerHandler( self, handler: logging.Handler ) -> None:
257+
"""Set a specific handler for the filter logger. In this filter 4 log levels are use, .info, .error, .warning and .critical, be sure to have at least the same 4 levels.
258+
259+
Args:
260+
handler (logging.Handler): The handler to add.
261+
"""
262+
if not self.logger.hasHandlers():
263+
self.logger.addHandler( handler )
264+
else:
265+
self.logger.warning(
266+
"The logger already has an handler, to use yours set the argument 'speHandler' to True during the filter initialization."
267+
)
268+
269+
def __locate_reference_point( self, multiBlockDataSet: vtkMultiBlockDataSet ) -> int:
270+
"""Locate the block to use as reference for the transformation.
271+
272+
Args:
273+
multiBlockDataSet (vtkMultiBlockDataSet): Input multiblock mesh.
274+
275+
Returns:
276+
int: Index of the block to use as reference.
277+
"""
278+
279+
def __inside( pt: npt.NDArray[ np.double ], bounds: tuple[ float, float, float, float, float, float ] ) -> bool:
280+
"""Check if a point is inside a box.
281+
282+
Args:
283+
pt (npt.NDArray[np.double]): Point to check.
284+
bounds (tuple[float, float, float, float, float, float]): Bounding box.
285+
286+
Returns:
287+
bool: True if the point is inside the bounding box, False otherwise.
288+
"""
289+
self.logger.info( f"Checking if point {pt} is inside bounds {bounds}" )
290+
return ( pt[ 0 ] >= bounds[ 0 ] and pt[ 0 ] <= bounds[ 1 ] and pt[ 1 ] >= bounds[ 2 ]
291+
and pt[ 1 ] <= bounds[ 3 ] and pt[ 2 ] >= bounds[ 4 ] and pt[ 2 ] <= bounds[ 5 ] )
292+
293+
DOIterator: vtkDataObjectTreeIterator = vtkDataObjectTreeIterator()
294+
DOIterator.SetDataSet( multiBlockDataSet )
295+
DOIterator.VisitOnlyLeavesOn()
296+
DOIterator.GoToFirstItem()
297+
xmin: float
298+
ymin: float
299+
zmin: float
300+
xmin, _, ymin, _, zmin, _ = getMultiBlockBounds( multiBlockDataSet )
301+
while DOIterator.GetCurrentDataObject() is not None:
302+
dataSet: vtkUnstructuredGrid = vtkUnstructuredGrid.SafeDownCast( DOIterator.GetCurrentDataObject() )
303+
bounds: Tuple[ float, float, float, float, float, float ] = dataSet.GetBounds()
304+
#use the furthest bounds corner as reference point in the all negs quadrant
305+
if __inside( np.asarray( [ xmin, ymin, zmin ] ), bounds ):
306+
self.logger.info( f"Using block {DOIterator.GetCurrentFlatIndex()} as reference for transformation" )
307+
return DOIterator.GetCurrentFlatIndex()
308+
DOIterator.GoToNextItem()
309+
310+
raise IndexError

0 commit comments

Comments
 (0)