It is expected that changes made in JAX are not reflected in the PETSc object.
The issue has been explained in my previous message (point 2).

Hong
________________________________
From: Alberto Cattaneo <bubu.catta...@gmail.com>
Sent: Tuesday, July 15, 2025 1:13 PM
To: Zhang, Hong <hongzh...@anl.gov>
Subject: Re: [petsc-users] Petsc/Jax no copy interfacing issues

Odd, I was using double precision (I forgot to include that in the example,
sorry), but on my machine I’m still not seeing the changes made reflected in the
PETSc object. Are the changes reflected on your end? Is it possibly an
ownership issue?

On Tue, Jul 8, 2025 at 3:56 PM Zhang, Hong <hongzh...@anl.gov> wrote:

Hi Alberto,

1. To check the array pointer on the PETSc side, you can do
print(hex(y_petsc.array.ctypes.data)). You will then see a pointer mismatch
caused by the line y = jnp.from_dlpack(y_petsc, copy=False). This is because
you configured PETSc in double precision, but JAX uses single precision by
default. You can either add jax.config.update("jax_enable_x64", True) to make
JAX use double-precision numbers, or configure PETSc for single precision
(--with-precision=single).

2. Once you fix this precision mismatch, the in-place conversion between PETSc
and JAX should work. However, .at[].set() in JAX is not guaranteed to operate
in place; by design, array updates in JAX are generally performed out of place.
You may want to do the updates in PETSc instead, so that they do not break the
zero-copy setup.
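A minimal sketch of both points, under stated assumptions: petsc4py is imported only if it is installed (its PETSc.ScalarType gives the configured scalar type); otherwise a real, double-precision PETSc build is assumed. For the owner-side update, a NumPy array stands in for the PETSc-owned buffer so the snippet runs without PETSc, and np.from_dlpack (NumPy 1.22 or later) plays the role of the zero-copy view.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Point 1: JAX defaults to float32. Enable 64-bit types before creating any
# arrays so that JAX agrees with a double-precision PETSc build.
jax.config.update("jax_enable_x64", True)

try:
    from petsc4py import PETSc
    petsc_dtype = np.dtype(PETSc.ScalarType)
except ImportError:
    # Assumption: a default real, double-precision PETSc build.
    petsc_dtype = np.dtype(np.float64)

jax_dtype = np.dtype(jnp.ones(4).dtype)
print(f"JAX dtype: {jax_dtype}, PETSc dtype: {petsc_dtype}, "
      f"match: {jax_dtype == petsc_dtype}")

# Point 2a: JAX updates are functional. .at[].set() returns a new array
# and leaves the original one untouched.
x = jnp.zeros(3)
x_new = x.at[:].set(3.0)

# Point 2b: write through the owner of the memory instead. A NumPy array
# stands in for the PETSc-owned buffer (with petsc4py the owner-side update
# would be a call such as y_petsc.set(3.0)); the DLPack view sees the change
# because no copy was made.
owner = np.zeros(3)
alias = np.from_dlpack(owner)
owner[:] = 3.0
print(x, x_new, alias)
```

The stand-in avoids depending on whether JAX's own CPU import can alias a given buffer, which is exactly what the thread is debugging.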

Hong



From: petsc-users <petsc-users-boun...@mcs.anl.gov> on behalf of Alberto Cattaneo <bubu.catta...@gmail.com>
Date: Monday, July 7, 2025 at 8:40 AM
To: "petsc-users@mcs.anl.gov" <petsc-users@mcs.anl.gov>
Subject: [petsc-users] Petsc/Jax no copy interfacing issues



Greetings.

I hope this email finds you well. I’m trying to get JAX and PETSc to work
together in a no-copy system using the DLPack tools in both. Unfortunately I
can’t seem to get it to work right. Ideally, I’d like to create a PETSc Vec
object using petsc4py, pass it to a JAX object without copying, make a
change to it in a JAX jitted function, and have that change reflected in the
PETSc object. All of this without copying.

Of note: when I call the from_dlpack function, I get an error saying that the
alignment is wrong and a copy must be made. Changing the alignment to 32 at the
PETSc ./configure stage makes the error message disappear, but even so it still
doesn’t work correctly. I’ve tried looking through the documentation, but I’m
getting a little turned around.
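A small helper for checking whether a buffer address meets a given alignment, for example the address of a PETSc Vec's array obtained with y_petsc.array.ctypes.data. A NumPy array stands in for the PETSc vector so the snippet runs without petsc4py; the alignments tested are illustrative only, since the alignment JAX actually requires for a zero-copy import is not stated in this thread.

```python
import numpy as np


def is_aligned(address: int, alignment: int) -> bool:
    """Return True if the integer address is a multiple of `alignment` bytes."""
    return address % alignment == 0


# Stand-in for a PETSc buffer; with petsc4py, use y_petsc.array.ctypes.data.
buf = np.zeros(1000)
addr = buf.ctypes.data

# Illustrative alignments; which one JAX needs is an assumption to verify.
report = {a: is_aligned(addr, a) for a in (8, 16, 32, 64)}
print(hex(addr), report)
```

Comparing this report before and after reconfiguring PETSc would show whether the buffer address itself changed alignment.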



I’ve included a code snippet below:



from functools import partial

import jax
import jax.numpy as jnp
from petsc4py import PETSc


# Donate the input buffer so that XLA is allowed (but not required) to reuse it.
@partial(jax.jit, donate_argnums=(0,))
def set_in_place(x):
    return x.at[:].set(3.0)


print('\nTesting jax from_dlpack given a PETSc vector that was allocated by PETSc')

x = jnp.ones((1000, 1))
y_petsc = PETSc.Vec().createSeq(x.shape[0])
y_petsc.set(0.0)
print(hex(y_petsc.handle))

# A second PETSc Vec that wraps the same memory through DLPack.
y2_petsc = PETSc.Vec().createWithDLPack(y_petsc.toDLPack('rw'))
y2_petsc.set(-1.0)
assert y_petsc.getValue(0) == y2_petsc.getValue(0)
print('After creating a second PETSc vector via a DLPack of the first, '
      'modifying the memory of one affects the other.')

# Import the PETSc buffer into JAX, asking for no copy.
#y = jnp.from_dlpack(y_petsc.toDLPack('rw'), copy=False)
y = jnp.from_dlpack(y_petsc, copy=False)
orig_ptr = y.unsafe_buffer_pointer()
print(f'before: ptr at {hex(orig_ptr)}')

y = set_in_place(y)
print(f'after:  ptr at {hex(y.unsafe_buffer_pointer())}')
assert orig_ptr == y.unsafe_buffer_pointer()

#assert y_petsc.getValue(0) == y[0], f'The PETSc value {y_petsc.getValue(0)} did not match the JAX value {y[0]}, so modifying the JAX memory did not affect the PETSc memory.'



I’d like the bottom two asserts to pass, but I can only get one of them. If
anybody is familiar with this issue, I’d greatly appreciate the assistance.

Respectfully:

Alberto
