https://github.com/python/cpython/commit/61f2a1a5993967ed4b97ba93a4477c37fe68cf59
commit: 61f2a1a5993967ed4b97ba93a4477c37fe68cf59
branch: main
author: Michiel W. Beijen <[email protected]>
committer: encukou <[email protected]>
date: 2026-03-13T14:10:48+01:00
summary:

GH-60729: Add IEEE format wave audio support (GH-145384)


Co-authored-by: Lionel Koenig <[email protected]>

files:
A Lib/test/audiodata/pluck-float32.wav
A Misc/NEWS.d/next/Library/2023-03-10-13-10-06.gh-issue-60729.KCCHTe.rst
M Doc/library/wave.rst
M Doc/whatsnew/3.15.rst
M Lib/test/audiotests.py
M Lib/test/test_wave.py
M Lib/wave.py

diff --git a/Doc/library/wave.rst b/Doc/library/wave.rst
index ff020b52da3f23..9d30a14f112937 100644
--- a/Doc/library/wave.rst
+++ b/Doc/library/wave.rst
@@ -9,14 +9,19 @@
 --------------
 
 The :mod:`!wave` module provides a convenient interface to the Waveform Audio
-"WAVE" (or "WAV") file format. Only uncompressed PCM encoded wave files are
-supported.
+"WAVE" (or "WAV") file format.
+
+The module supports uncompressed PCM and IEEE floating-point WAV formats.
 
 .. versionchanged:: 3.12
 
    Support for ``WAVE_FORMAT_EXTENSIBLE`` headers was added, provided that the
    extended format is ``KSDATAFORMAT_SUBTYPE_PCM``.
 
+.. versionchanged:: next
+
+   Support for reading and writing ``WAVE_FORMAT_IEEE_FLOAT`` files was added.
+
 The :mod:`!wave` module defines the following function and exception:
 
 
@@ -60,6 +65,21 @@ The :mod:`!wave` module defines the following function and 
exception:
    specification or hits an implementation deficiency.
 
 
+.. data:: WAVE_FORMAT_PCM
+
+   Format code for uncompressed PCM audio.
+
+
+.. data:: WAVE_FORMAT_IEEE_FLOAT
+
+   Format code for IEEE floating-point audio.
+
+
+.. data:: WAVE_FORMAT_EXTENSIBLE
+
+   Format code for WAVE extensible headers.
+
+
 .. _wave-read-objects:
 
 Wave_read Objects
@@ -98,6 +118,14 @@ Wave_read Objects
       Returns number of audio frames.
 
 
+   .. method:: getformat()
+
+      Returns the frame format code.
+
+      This is one of :data:`WAVE_FORMAT_PCM`,
+      :data:`WAVE_FORMAT_IEEE_FLOAT`, or :data:`WAVE_FORMAT_EXTENSIBLE`.
+
+
    .. method:: getcomptype()
 
       Returns compression type (``'NONE'`` is the only supported type).
@@ -112,8 +140,8 @@ Wave_read Objects
    .. method:: getparams()
 
       Returns a :func:`~collections.namedtuple` ``(nchannels, sampwidth,
       framerate, nframes, comptype, compname)``, equivalent to output of the
-      ``get*()`` methods.
+      ``get*()`` methods other than :meth:`getformat`.
 
 
    .. method:: readframes(n)
@@ -190,6 +218,9 @@ Wave_write Objects
 
       Set the sample width to *n* bytes.
 
+      For :data:`WAVE_FORMAT_IEEE_FLOAT`, only 4-byte (32-bit) and
+      8-byte (64-bit) sample widths are supported.
+
 
    .. method:: getsampwidth()
 
@@ -238,11 +269,32 @@ Wave_write Objects
       Return the human-readable compression type name.
 
 
+   .. method:: setformat(format)
+
+      Set the frame format code.
+
+      Supported values are :data:`WAVE_FORMAT_PCM` and
+      :data:`WAVE_FORMAT_IEEE_FLOAT`.
+
+      When setting :data:`WAVE_FORMAT_IEEE_FLOAT`, a sample width that has
+      already been set must be 4 or 8 bytes, else :exc:`wave.Error` is raised.
+
+
+   .. method:: getformat()
+
+      Return the current frame format code.
+
+
    .. method:: setparams(tuple)
 
-      The *tuple* should be ``(nchannels, sampwidth, framerate, nframes, 
comptype,
-      compname)``, with values valid for the ``set*()`` methods.  Sets all
-      parameters.
+      The *tuple* should be
+      ``(nchannels, sampwidth, framerate, nframes, comptype, compname, 
format)``,
+      with values valid for the ``set*()`` methods. Sets all parameters.
+
+      For backwards compatibility, a 6-item tuple without *format* is also
+      accepted and defaults to :data:`WAVE_FORMAT_PCM`.
+
+      For ``format=WAVE_FORMAT_IEEE_FLOAT``, *sampwidth* must be 4 or 8.
 
 
    .. method:: getparams()
@@ -279,3 +331,6 @@ Wave_write Objects
       Note that it is invalid to set any parameters after calling 
:meth:`writeframes`
       or :meth:`writeframesraw`, and any attempt to do so will raise
       :exc:`wave.Error`.
+
+      For :data:`WAVE_FORMAT_IEEE_FLOAT` output, a ``fact`` chunk is written as
+      required by the WAVE specification for non-PCM formats.
diff --git a/Doc/whatsnew/3.15.rst b/Doc/whatsnew/3.15.rst
index 459846e55ccf70..d5b14216770906 100644
--- a/Doc/whatsnew/3.15.rst
+++ b/Doc/whatsnew/3.15.rst
@@ -1518,6 +1518,21 @@ typing
 wave
 ----
 
+* Added support for IEEE floating-point WAVE audio
+  (``WAVE_FORMAT_IEEE_FLOAT``) in :mod:`wave`.
+
+* Added :meth:`wave.Wave_read.getformat`, :meth:`wave.Wave_write.getformat`,
+  and :meth:`wave.Wave_write.setformat` for explicit frame format handling.
+
+* :meth:`wave.Wave_write.setparams` accepts both 7-item tuples including
+  ``format`` and 6-item tuples for backwards compatibility (defaulting to
+  ``WAVE_FORMAT_PCM``).
+
+* Files written in ``WAVE_FORMAT_IEEE_FLOAT`` format include a ``fact``
+  chunk, as required for non-PCM WAVE formats.
+
+(Contributed by Lionel Koenig and Michiel W. Beijen in :gh:`60729`.)
+
 * Removed the ``getmark()``, ``setmark()`` and ``getmarkers()`` methods
   of the :class:`~wave.Wave_read` and :class:`~wave.Wave_write` classes,
   which were deprecated since Python 3.13.
diff --git a/Lib/test/audiodata/pluck-float32.wav 
b/Lib/test/audiodata/pluck-float32.wav
new file mode 100644
index 00000000000000..2030fb16d6e3bd
Binary files /dev/null and b/Lib/test/audiodata/pluck-float32.wav differ
diff --git a/Lib/test/audiotests.py b/Lib/test/audiotests.py
index 9d6c4cc2b4b02c..394097df17dca9 100644
--- a/Lib/test/audiotests.py
+++ b/Lib/test/audiotests.py
@@ -27,17 +27,18 @@ def tearDown(self):
         unlink(TESTFN)
 
     def check_params(self, f, nchannels, sampwidth, framerate, nframes,
-                     comptype, compname):
+                     comptype, compname, format):
         self.assertEqual(f.getnchannels(), nchannels)
         self.assertEqual(f.getsampwidth(), sampwidth)
         self.assertEqual(f.getframerate(), framerate)
         self.assertEqual(f.getnframes(), nframes)
         self.assertEqual(f.getcomptype(), comptype)
         self.assertEqual(f.getcompname(), compname)
+        self.assertEqual(f.getformat(), format)
 
         params = f.getparams()
         self.assertEqual(params,
-                (nchannels, sampwidth, framerate, nframes, comptype, compname))
+            (nchannels, sampwidth, framerate, nframes, comptype, compname))
         self.assertEqual(params.nchannels, nchannels)
         self.assertEqual(params.sampwidth, sampwidth)
         self.assertEqual(params.framerate, framerate)
@@ -51,13 +52,17 @@ def check_params(self, f, nchannels, sampwidth, framerate, 
nframes,
 
 
 class AudioWriteTests(AudioTests):
+    readonly = False
 
     def create_file(self, testfile):
+        if self.readonly:
+            self.skipTest('Read only file format')
         f = self.fout = self.module.open(testfile, 'wb')
         f.setnchannels(self.nchannels)
         f.setsampwidth(self.sampwidth)
         f.setframerate(self.framerate)
         f.setcomptype(self.comptype, self.compname)
+        f.setformat(self.format)
         return f
 
     def check_file(self, testfile, nframes, frames):
@@ -67,13 +72,14 @@ def check_file(self, testfile, nframes, frames):
             self.assertEqual(f.getframerate(), self.framerate)
             self.assertEqual(f.getnframes(), nframes)
             self.assertEqual(f.readframes(nframes), frames)
+            self.assertEqual(f.getformat(), self.format)
 
     def test_write_params(self):
         f = self.create_file(TESTFN)
         f.setnframes(self.nframes)
         f.writeframes(self.frames)
         self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
-                          self.nframes, self.comptype, self.compname)
+                          self.nframes, self.comptype, self.compname, 
self.format)
         f.close()
 
     def test_write_context_manager_calls_close(self):
@@ -257,7 +263,7 @@ def test_read_params(self):
         f = self.f = self.module.open(self.sndfilepath)
         #self.assertEqual(f.getfp().name, self.sndfilepath)
         self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
-                          self.sndfilenframes, self.comptype, self.compname)
+                          self.sndfilenframes, self.comptype, self.compname, 
self.format)
 
     def test_close(self):
         with open(self.sndfilepath, 'rb') as testfile:
@@ -298,6 +304,8 @@ def test_read(self):
             f.setpos(f.getnframes() + 1)
 
     def test_copy(self):
+        if self.readonly:
+            self.skipTest('Read only file format')
         f = self.f = self.module.open(self.sndfilepath)
         fout = self.fout = self.module.open(TESTFN, 'wb')
         fout.setparams(f.getparams())
diff --git a/Lib/test/test_wave.py b/Lib/test/test_wave.py
index 4c21f16553775c..a1afe91e3774b9 100644
--- a/Lib/test/test_wave.py
+++ b/Lib/test/test_wave.py
@@ -1,7 +1,7 @@
 import unittest
 from test import audiotests
 from test import support
-from test.support.os_helper import FakePath
+from test.support.os_helper import FakePath, unlink
 import io
 import os
 import struct
@@ -22,6 +22,7 @@ class WavePCM8Test(WaveTest, unittest.TestCase):
     sampwidth = 1
     framerate = 11025
     nframes = 48
+    format = wave.WAVE_FORMAT_PCM
     comptype = 'NONE'
     compname = 'not compressed'
     frames = bytes.fromhex("""\
@@ -39,6 +40,7 @@ class WavePCM16Test(WaveTest, unittest.TestCase):
     sampwidth = 2
     framerate = 11025
     nframes = 48
+    format = wave.WAVE_FORMAT_PCM
     comptype = 'NONE'
     compname = 'not compressed'
     frames = bytes.fromhex("""\
@@ -60,6 +62,7 @@ class WavePCM24Test(WaveTest, unittest.TestCase):
     sampwidth = 3
     framerate = 11025
     nframes = 48
+    format = wave.WAVE_FORMAT_PCM
     comptype = 'NONE'
     compname = 'not compressed'
     frames = bytes.fromhex("""\
@@ -87,6 +90,8 @@ class WavePCM24ExtTest(WaveTest, unittest.TestCase):
     sampwidth = 3
     framerate = 11025
     nframes = 48
+    format = wave.WAVE_FORMAT_EXTENSIBLE
+    readonly = True  # Writing EXTENSIBLE wave format is not supported.
     comptype = 'NONE'
     compname = 'not compressed'
     frames = bytes.fromhex("""\
@@ -114,6 +119,7 @@ class WavePCM32Test(WaveTest, unittest.TestCase):
     sampwidth = 4
     framerate = 11025
     nframes = 48
+    format = wave.WAVE_FORMAT_PCM
     comptype = 'NONE'
     compname = 'not compressed'
     frames = bytes.fromhex("""\
@@ -134,14 +140,140 @@ class WavePCM32Test(WaveTest, unittest.TestCase):
         frames = wave._byteswap(frames, 4)
 
 
+class WaveIeeeFloatingPointTest(WaveTest, unittest.TestCase):
+    sndfilename = 'pluck-float32.wav'
+    sndfilenframes = 3307
+    nchannels = 2
+    sampwidth = 4
+    framerate = 11025
+    nframes = 48
+    format = wave.WAVE_FORMAT_IEEE_FLOAT
+    comptype = 'NONE'
+    compname = 'not compressed'
+    frames = bytes.fromhex("""\
+      60598B3C001423BA 1FB4163F8054FA3B 0E4FC43E80C51D3D 53467EBF4030843D \
+      FC84D0BE304C563D 3053113F40BEFC3C B72F00BFC03E583C E0FEDA3C805142BC \
+      54510FBFE02638BD 569F16BF40FDCABD C060A63EECA421BE 3CE5523E2C3349BE \
+      0C2E10BE14725BBE 5268E7BEDC3B6CBE 985AE03D80497ABE B4B606BEECB67EBE \
+      B0B12E3FC87C6CBE 005519BD4C0F3EBE F8BD1B3EECDF03BE 924E9FBE588D8DBD \
+      D4E150BF501711BD B079A0BD20FBFBBC 5863863D40760CBD 0E3C83BE40E217BD \
+      04FF0B3EF07839BD E29AFB3E80A714BD B91007BFE042D3BC B5AD4D3F80CDA0BB \
+      1AB1C3BEB04E023D D33A063FC0A8973D 8012F9BEE074EC3D 7341223FD415153E \
+      D80409BE04A63A3E 00F27BBFBC25333E 0000803FFC29223E 000080BF38A7143E \
+      3638133F283BEB3D 7C6E253F00CADB3D 686A02BE88FDF53D 920CC7BE28E1FB3D \
+      185B5ABED8A2CE3D 5189463FC8A7A53D E88F8C3DF0FFA13D 1CE6AE3EE0A0B03D \
+      DF90223F184EE43D 376768BF2CD8093E 281612BF60B3EE3D 2F26083F88B4A53D \
+      """)
+
 class MiscTestCase(unittest.TestCase):
     def test__all__(self):
-        not_exported = {'WAVE_FORMAT_PCM', 'WAVE_FORMAT_EXTENSIBLE', 
'KSDATAFORMAT_SUBTYPE_PCM'}
+        not_exported = {'KSDATAFORMAT_SUBTYPE_PCM'}
         support.check__all__(self, wave, not_exported=not_exported)
 
 
 class WaveLowLevelTest(unittest.TestCase):
 
+    def test_setparams_6_tuple_defaults_to_pcm(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
+            w.setparams((1, 2, 22050, 0, 'NONE', 'not compressed'))
+            self.assertEqual(w.getformat(), wave.WAVE_FORMAT_PCM)
+
+    def test_setparams_7_tuple_uses_format(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setparams((1, 4, 22050, 0, 'NONE', 'not compressed',
+                         wave.WAVE_FORMAT_IEEE_FLOAT))
+            self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
+
+    def test_setparams_7_tuple_ieee_64bit_sampwidth(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setparams((1, 8, 22050, 0, 'NONE', 'not compressed',
+                         wave.WAVE_FORMAT_IEEE_FLOAT))
+            self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
+            self.assertEqual(w.getsampwidth(), 8)
+
+    def test_getparams_backward_compatible_shape(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setparams((1, 4, 22050, 0, 'NONE', 'not compressed',
+                         wave.WAVE_FORMAT_IEEE_FLOAT))
+            params = w.getparams()
+            self.assertEqual(params, (1, 4, 22050, 0, 'NONE', 'not 
compressed'))
+
+    def test_getformat_setformat(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setnchannels(1)
+            w.setsampwidth(4)
+            w.setframerate(22050)
+            self.assertEqual(w.getformat(), wave.WAVE_FORMAT_PCM)
+            w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
+            self.assertEqual(w.getformat(), wave.WAVE_FORMAT_IEEE_FLOAT)
+
+    def test_setformat_ieee_requires_32_or_64_bit_sampwidth(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setnchannels(1)
+            w.setsampwidth(2)
+            w.setframerate(22050)
+            with self.assertRaisesRegex(wave.Error,
+                                        'unsupported sample width for IEEE 
float format'):
+                w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
+
+    def test_setsampwidth_ieee_requires_32_or_64_bit(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setnchannels(1)
+            w.setframerate(22050)
+            w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
+            with self.assertRaisesRegex(wave.Error,
+                                        'unsupported sample width for IEEE 
float format'):
+                w.setsampwidth(2)
+            w.setsampwidth(4)
+
+    def test_setsampwidth_ieee_accepts_64_bit(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setnchannels(1)
+            w.setframerate(22050)
+            w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
+            w.setsampwidth(8)
+            self.assertEqual(w.getsampwidth(), 8)
+
+    def test_read_getformat(self):
+        b = b'RIFF' + struct.pack('<L', 36) + b'WAVE'
+        b += b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 11025, 11025, 1, 8)
+        b += b'data' + struct.pack('<L', 0)
+        with wave.open(io.BytesIO(b), 'rb') as r:
+            self.assertEqual(r.getformat(), wave.WAVE_FORMAT_PCM)
+
     def test_read_no_chunks(self):
         b = b'SPAM'
         with self.assertRaises(EOFError):
@@ -207,6 +339,58 @@ def test_open_in_write_raises(self):
             support.gc_collect()
             self.assertIsNone(cm.unraisable)
 
+    def test_ieee_float_has_fact_chunk(self):
+        nframes = 100
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setnchannels(1)
+            w.setsampwidth(4)
+            w.setframerate(22050)
+            w.setformat(wave.WAVE_FORMAT_IEEE_FLOAT)
+            w.writeframes(b'\x00\x00\x00\x00' * nframes)
+
+        with open(filename, 'rb') as f:
+            f.read(12)
+            fact_found = False
+            fact_samples = None
+            while True:
+                chunk_id = f.read(4)
+                if len(chunk_id) < 4:
+                    break
+                chunk_size = struct.unpack('<L', f.read(4))[0]
+                if chunk_id == b'fact':
+                    fact_found = True
+                    fact_samples = struct.unpack('<L', f.read(4))[0]
+                    break
+                f.seek(chunk_size + (chunk_size & 1), 1)
+
+        self.assertTrue(fact_found)
+        self.assertEqual(fact_samples, nframes)
+
+    def test_pcm_has_no_fact_chunk(self):
+        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
+            filename = fp.name
+        self.addCleanup(unlink, filename)
+
+        with wave.open(filename, 'wb') as w:
+            w.setnchannels(1)
+            w.setsampwidth(2)
+            w.setframerate(22050)
+            w.writeframes(b'\x00\x00' * 100)
+
+        with open(filename, 'rb') as f:
+            f.read(12)
+            while True:
+                chunk_id = f.read(4)
+                if len(chunk_id) < 4:
+                    break
+                chunk_size = struct.unpack('<L', f.read(4))[0]
+                self.assertNotEqual(chunk_id, b'fact')
+                f.seek(chunk_size + (chunk_size & 1), 1)
+
 
 class WaveOpen(unittest.TestCase):
     def test_open_pathlike(self):
diff --git a/Lib/wave.py b/Lib/wave.py
index 25ca9ef168e8a5..92e1f34c356161 100644
--- a/Lib/wave.py
+++ b/Lib/wave.py
@@ -15,6 +15,8 @@
       getsampwidth()  -- returns sample width in bytes
       getframerate()  -- returns sampling frequency
       getnframes()    -- returns number of audio frames
+      getformat()     -- returns frame encoding (WAVE_FORMAT_PCM, 
WAVE_FORMAT_IEEE_FLOAT
+                         or WAVE_FORMAT_EXTENSIBLE)
       getcomptype()   -- returns compression type ('NONE' for linear samples)
       getcompname()   -- returns human-readable version of
                          compression type ('not compressed' linear samples)
@@ -42,6 +44,9 @@
       setsampwidth(n) -- set the sample width
       setframerate(n) -- set the frame rate
       setnframes(n)   -- set the number of frames
+      setformat(format)
+                      -- set the frame format. Only WAVE_FORMAT_PCM and
+                         WAVE_FORMAT_IEEE_FLOAT are supported.
       setcomptype(type, name)
                       -- set the compression type and the
                          human-readable compression type
@@ -74,12 +79,21 @@
 import sys
 
 
-__all__ = ["open", "Error", "Wave_read", "Wave_write"]
+__all__ = [
+    "open",
+    "Error",
+    "Wave_read",
+    "Wave_write",
+    "WAVE_FORMAT_PCM",
+    "WAVE_FORMAT_IEEE_FLOAT",
+    "WAVE_FORMAT_EXTENSIBLE",
+]
 
 class Error(Exception):
     pass
 
 WAVE_FORMAT_PCM = 0x0001
+WAVE_FORMAT_IEEE_FLOAT = 0x0003
 WAVE_FORMAT_EXTENSIBLE = 0xFFFE
 # Derived from uuid.UUID("00000001-0000-0010-8000-00aa00389b71").bytes_le
 KSDATAFORMAT_SUBTYPE_PCM = 
b'\x01\x00\x00\x00\x00\x00\x10\x00\x80\x00\x00\xaa\x008\x9bq'
@@ -226,6 +240,10 @@ class Wave_read:
               available through the getsampwidth() method
     _framerate -- the sampling frequency
               available through the getframerate() method
+    _format -- the frame format code read from the fmt chunk:
+              one of WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT
+              or WAVE_FORMAT_EXTENSIBLE; available through
+              the getformat() method
     _comptype -- the AIFF-C compression type ('NONE' if AIFF)
               available through the getcomptype() method
     _compname -- the human-readable AIFF-C compression type
@@ -327,6 +345,9 @@ def getsampwidth(self):
     def getframerate(self):
         return self._framerate
 
+    def getformat(self):
+        return self._format
+
     def getcomptype(self):
         return self._comptype
 
@@ -367,16 +388,16 @@ def readframes(self, nframes):
 
     def _read_fmt_chunk(self, chunk):
         try:
-            wFormatTag, self._nchannels, self._framerate, dwAvgBytesPerSec, 
wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
+            self._format, self._nchannels, self._framerate, dwAvgBytesPerSec, 
wBlockAlign = struct.unpack_from('<HHLLH', chunk.read(14))
         except struct.error:
             raise EOFError from None
-        if wFormatTag != WAVE_FORMAT_PCM and wFormatTag != 
WAVE_FORMAT_EXTENSIBLE:
-            raise Error('unknown format: %r' % (wFormatTag,))
+        if self._format not in (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT, 
WAVE_FORMAT_EXTENSIBLE):
+            raise Error('unknown format: %r' % (self._format,))
         try:
             sampwidth = struct.unpack_from('<H', chunk.read(2))[0]
         except struct.error:
             raise EOFError from None
-        if wFormatTag == WAVE_FORMAT_EXTENSIBLE:
+        if self._format == WAVE_FORMAT_EXTENSIBLE:
             try:
                 cbSize, wValidBitsPerSample, dwChannelMask = 
struct.unpack_from('<HHL', chunk.read(8))
                 # Read the entire UUID from the chunk
@@ -419,6 +440,8 @@ class Wave_write:
               set through the setsampwidth() or setparams() method
     _framerate -- the sampling frequency
               set through the setframerate() or setparams() method
+    _format -- the frame format code (WAVE_FORMAT_PCM by default)
+              set through the setformat() or setparams() method
     _nframes -- the number of audio frames written to the header
               set through the setnframes() or setparams() method
 
@@ -446,12 +469,14 @@ def initfp(self, file):
         self._file = file
         self._convert = None
         self._nchannels = 0
+        self._format = WAVE_FORMAT_PCM
         self._sampwidth = 0
         self._framerate = 0
         self._nframes = 0
         self._nframeswritten = 0
         self._datawritten = 0
         self._datalength = 0
+        self._fact_sample_count_pos = None
         self._headerwritten = False
 
     def __del__(self):
@@ -481,7 +506,10 @@ def getnchannels(self):
     def setsampwidth(self, sampwidth):
         if self._datawritten:
             raise Error('cannot change parameters after starting to write')
-        if sampwidth < 1 or sampwidth > 4:
+        if self._format == WAVE_FORMAT_IEEE_FLOAT:
+            if sampwidth not in (4, 8):
+                raise Error('unsupported sample width for IEEE float format')
+        elif sampwidth < 1 or sampwidth > 4:
             raise Error('bad sample width')
         self._sampwidth = sampwidth
 
@@ -518,6 +546,18 @@ def setcomptype(self, comptype, compname):
         self._comptype = comptype
         self._compname = compname
 
+    def setformat(self, format):
+        if self._datawritten:
+            raise Error('cannot change parameters after starting to write')
+        if format not in (WAVE_FORMAT_IEEE_FLOAT, WAVE_FORMAT_PCM):
+            raise Error('unsupported wave format')
+        if format == WAVE_FORMAT_IEEE_FLOAT and self._sampwidth and 
self._sampwidth not in (4, 8):
+            raise Error('unsupported sample width for IEEE float format')
+        self._format = format
+
+    def getformat(self):
+        return self._format
+
     def getcomptype(self):
         return self._comptype
 
@@ -525,10 +565,15 @@ def getcompname(self):
         return self._compname
 
     def setparams(self, params):
-        nchannels, sampwidth, framerate, nframes, comptype, compname = params
         if self._datawritten:
             raise Error('cannot change parameters after starting to write')
+        if len(params) == 6:
+            params = (*params, WAVE_FORMAT_PCM)  # legacy tuple: assume PCM
+        nchannels, sampwidth, framerate, nframes, comptype, compname, format = params
         self.setnchannels(nchannels)
+        # Check format and width as a new pair, not against the old width.
+        self._sampwidth = 0
+        self.setformat(format)
         self.setsampwidth(sampwidth)
         self.setframerate(framerate)
         self.setnframes(nframes)
@@ -589,6 +634,9 @@ def _ensure_header_written(self, datasize):
                 raise Error('sampling rate not specified')
             self._write_header(datasize)
 
+    def _needs_fact_chunk(self):
+        return self._format == WAVE_FORMAT_IEEE_FLOAT
+
     def _write_header(self, initlength):
         assert not self._headerwritten
         self._file.write(b'RIFF')
@@ -599,12 +647,23 @@ def _write_header(self, initlength):
             self._form_length_pos = self._file.tell()
         except (AttributeError, OSError):
             self._form_length_pos = None
-        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
-            36 + self._datalength, b'WAVE', b'fmt ', 16,
-            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
+        has_fact = self._needs_fact_chunk()
+        header_overhead = 36 + (12 if has_fact else 0)
+        self._file.write(struct.pack('<L4s4sLHHLLHH',
+            header_overhead + self._datalength, b'WAVE', b'fmt ', 16,
+            self._format, self._nchannels, self._framerate,
             self._nchannels * self._framerate * self._sampwidth,
             self._nchannels * self._sampwidth,
-            self._sampwidth * 8, b'data'))
+            self._sampwidth * 8))
+        if has_fact:
+            self._file.write(b'fact')
+            self._file.write(struct.pack('<L', 4))
+            try:
+                self._fact_sample_count_pos = self._file.tell()
+            except (AttributeError, OSError):
+                self._fact_sample_count_pos = None
+            self._file.write(struct.pack('<L', self._nframes))
+        self._file.write(b'data')
         if self._form_length_pos is not None:
             self._data_length_pos = self._file.tell()
         self._file.write(struct.pack('<L', self._datalength))
@@ -615,8 +674,13 @@ def _patchheader(self):
         if self._datawritten == self._datalength:
             return
         curpos = self._file.tell()
+        header_overhead = 36 + (12 if self._needs_fact_chunk() else 0)
         self._file.seek(self._form_length_pos, 0)
-        self._file.write(struct.pack('<L', 36 + self._datawritten))
+        self._file.write(struct.pack('<L', header_overhead + 
self._datawritten))
+        if self._fact_sample_count_pos is not None:
+            self._file.seek(self._fact_sample_count_pos, 0)
+            nframes = self._datawritten // (self._nchannels * self._sampwidth)
+            self._file.write(struct.pack('<L', nframes))
         self._file.seek(self._data_length_pos, 0)
         self._file.write(struct.pack('<L', self._datawritten))
         self._file.seek(curpos, 0)
diff --git 
a/Misc/NEWS.d/next/Library/2023-03-10-13-10-06.gh-issue-60729.KCCHTe.rst 
b/Misc/NEWS.d/next/Library/2023-03-10-13-10-06.gh-issue-60729.KCCHTe.rst
new file mode 100644
index 00000000000000..82876cd81e4a19
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-03-10-13-10-06.gh-issue-60729.KCCHTe.rst
@@ -0,0 +1 @@
+Add read and write support for IEEE floating-point WAVE files in :mod:`wave`.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3//lists/python-checkins.python.org
Member address: [email protected]

Reply via email to