It is expected that changes made in JAX are not reflected in the PETSc object. The issue was explained in my previous message (point 2).
Hong

________________________________
From: Alberto Cattaneo <bubu.catta...@gmail.com>
Sent: Tuesday, July 15, 2025 1:13 PM
To: Zhang, Hong <hongzh...@anl.gov>
Subject: Re: [petsc-users] Petsc/Jax no copy interfacing issues

Odd. I was using double precision (I forgot to include that in the example, sorry), but on my machine I’m still not seeing the changes reflected in the PETSc object. Are the changes reflected on your end? Is it possibly an ownership issue?

On Tue, Jul 8, 2025 at 3:56 PM Zhang, Hong <hongzh...@anl.gov> wrote:

Hi Alberto,

1. To check the array pointer on the PETSc side, you can do print(hex(y_petsc.array.ctypes.data)). You will then see a pointer mismatch caused by the line y = jnp.from_dlpack(y_petsc, copy=False). This is because you configured PETSc in double precision, but JAX uses single precision by default. You can either add jax.config.update("jax_enable_x64", True) to make JAX use double-precision numbers, or configure PETSc to support single precision.

2. Once you fix this precision mismatch, the in-place conversion between PETSc and JAX should work. However, .at[].set() in JAX is not guaranteed to operate in place; array updates in JAX are generally performed out-of-place by design. You may do the updates in PETSc instead, so that they don't break the zero-copy system.
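Both failure modes can be reproduced without PETSc or JAX. The sketch below is a stand-in, assuming NumPy 1.23 or later (for np.from_dlpack): a float64 NumPy array plays the PETSc-owned buffer, a DLPack view of it plays the zero-copy JAX array, astype(np.float32) plays the single-precision conversion described in point 1, and the hypothetical helper functional_set mimics the out-of-place semantics of .at[].set() from point 2.

```python
import numpy as np


def buffer_address(a: np.ndarray) -> int:
    """Address of the first element, analogous to y_petsc.array.ctypes.data."""
    return a.ctypes.data


def functional_set(a: np.ndarray, value: float) -> np.ndarray:
    """Hypothetical stand-in for JAX's x.at[:].set(value): it returns a new
    array and never writes into the input buffer."""
    out = np.empty_like(a)
    out[...] = value
    return out


# The "owner" buffer, playing the role of the double-precision PETSc Vec.
owner = np.zeros(1000, dtype=np.float64)

# Point 1: a consumer that insists on float32 must copy, so the pointers differ.
converted = owner.astype(np.float32)
# A consumer that accepts float64 can share the same buffer through DLPack.
shared = np.from_dlpack(owner)

print("owner     :", hex(buffer_address(owner)))
print("converted :", hex(buffer_address(converted)))
print("shared    :", hex(buffer_address(shared)))

# Point 2: a functional update allocates a new buffer; the owner is untouched.
updated = functional_set(shared, 3.0)
print("owner[0] after functional update:", owner[0])

# Writing through the owner in place is visible through the shared view.
owner[:] = 3.0
print("shared[0] after in-place owner update:", shared[0])
```

In the real setup, the analogue of the last step would be y_petsc.set(3.0) while the JAX array aliases the same buffer. JAX models arrays as immutable, so whether it observes such external writes should be verified on your JAX version rather than assumed.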
Hong

From: petsc-users <petsc-users-boun...@mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.catta...@gmail.com>
Date: Monday, July 7, 2025 at 8:40 AM
To: "petsc-users@mcs.anl.gov" <petsc-users@mcs.anl.gov>
Subject: [petsc-users] Petsc/Jax no copy interfacing issues

Greetings. I hope this email finds you well.

I’m trying to get JAX and PETSc to work together in a no-copy setup using the DLPack tools in both. Unfortunately, I can’t seem to get it to work right. Ideally, I’d like to create a PETSc Vec object using petsc4py, pass it to JAX without copying, modify it in a JAX-jitted function, and have that change reflected in the PETSc object, all without copying.

Of note: when I call the from_dlpack function, I get an error saying the alignment is wrong and a copy must be made. Changing the alignment to 32 at the PETSc ./configure stage makes the error message disappear, but even so it still doesn’t function correctly.

I’ve tried looking through the documentation, but I’m getting a little turned around.
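The alignment complaint can be checked directly: a buffer is n-byte aligned when the address of its first element is a multiple of n. A minimal sketch, assuming NumPy; is_aligned and aligned_empty are hypothetical helpers, and 32 is simply the value that silenced the error in this thread, not a documented JAX requirement. In the real setup, the address to test would be y_petsc.array.ctypes.data.

```python
import numpy as np


def is_aligned(a: np.ndarray, alignment: int) -> bool:
    """True if the first element's address is a multiple of `alignment` bytes."""
    return a.ctypes.data % alignment == 0


def aligned_empty(n: int, dtype=np.float64, alignment: int = 32) -> np.ndarray:
    """Hypothetical helper: allocate n elements starting on an `alignment`-byte
    boundary by over-allocating raw bytes and slicing at the right offset."""
    itemsize = np.dtype(dtype).itemsize
    raw = np.empty(n * itemsize + alignment, dtype=np.uint8)
    offset = (-raw.ctypes.data) % alignment
    return raw[offset:offset + n * itemsize].view(dtype)


buf = aligned_empty(1000, alignment=32)
print("buffer at", hex(buf.ctypes.data), "32-byte aligned:", is_aligned(buf, 32))

# Skipping one float64 (8 bytes) breaks 32-byte alignment but keeps 8-byte alignment.
shifted = buf[1:]
print("shifted view 32-byte aligned:", is_aligned(shifted, 32))
print("shifted view 8-byte aligned:", is_aligned(shifted, 8))
```

For PETSc itself the alignment of Vec storage is fixed at configure time, as done above, so a check like this only tells you whether the buffer handed to from_dlpack meets the alignment the error message asks for.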
I’ve included a code snippet below:

from petsc4py import PETSc
import jax
from functools import partial
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def set_in_place(x):
    return x.at[:].set(3.0)

print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')
x = jnp.ones((1000, 1))
y_petsc = PETSc.Vec().createSeq(x.shape[0])
y_petsc.set(0.0)
print(hex(y_petsc.handle))

y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
y2_petsc.set(-1.0)
assert y_petsc.getValue(0) == y2_petsc.getValue(0)
print('After creating a second PETSc vector via a DLPack of the first, modifying the memory of one affects the other.')

#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
y = jnp.from_dlpack(y_petsc, copy=False)
orig_ptr = y.unsafe_buffer_pointer()
print(f'before: ptr at {hex(orig_ptr)}')
y = set_in_place(y)
print(f'after: ptr at {hex(y.unsafe_buffer_pointer())}')
assert orig_ptr == y.unsafe_buffer_pointer()
#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'

I’d like the bottom two asserts to pass, but I can only get one of them.
If anybody is familiar with this issue, I’d greatly appreciate the assistance.

Respectfully,
Alberto