From 6a0d088cbe7d4ece3762cb867e3d6588d18ad8f6 Mon Sep 17 00:00:00 2001
From: Jacob Champion <jacob.champion@enterprisedb.com>
Date: Thu, 25 Apr 2024 15:26:40 -0700
Subject: [PATCH v4 1/2] jsonapi: add lexer option to keep token ownership

Commit 0785d1b8b adds support for libpq as a JSON client, but
allocations for string tokens can still be leaked during parsing
failures. This is tricky to fix for the object_field semantic callbacks:
the field name must remain valid until the field's value has been
parsed, but if a parsing error is encountered partway through,
object_field_end() won't be invoked and the client won't get a chance to
free the field name.

At Andrew's suggestion, add a flag to switch the ownership of parsed
tokens to the lexer. When this is enabled, the client must make a copy
of any tokens it wants to persist past the callback lifetime, but the
lexer will handle necessary cleanup on failure.

A -o option has been added to test_json_parser_incremental to exercise
the new setJsonLexContextOwnsTokens() API, and the test_json_parser TAP
tests make use of it. (The test program now cleans up allocated memory,
so that tests can be usefully run under leak sanitizers.)
---
 src/common/jsonapi.c                          | 102 ++++++++++++++++--
 src/include/common/jsonapi.h                  |  28 ++++-
 .../t/001_test_json_parser_incremental.pl     |  13 ++-
 .../modules/test_json_parser/t/002_inline.pl  |  15 +--
 .../test_json_parser/t/003_test_semantic.pl   |  11 +-
 .../test_json_parser_incremental.c            |  37 +++++--
 6 files changed, 173 insertions(+), 33 deletions(-)

diff --git a/src/common/jsonapi.c b/src/common/jsonapi.c
index 45838d8a18..df6e633b5e 100644
--- a/src/common/jsonapi.c
+++ b/src/common/jsonapi.c
@@ -161,6 +161,7 @@ struct JsonParserStack
  */
 struct JsonIncrementalState
 {
+	bool		started;
 	bool		is_last_chunk;
 	bool		partial_completed;
 	jsonapi_StrValType partial_token;
@@ -280,6 +281,7 @@ static JsonParseErrorType parse_array_element(JsonLexContext *lex, const JsonSem
 static JsonParseErrorType parse_array(JsonLexContext *lex, const JsonSemAction *sem);
 static JsonParseErrorType report_parse_error(JsonParseContext ctx, JsonLexContext *lex);
 static bool allocate_incremental_state(JsonLexContext *lex);
+static inline void set_fname(JsonLexContext *lex, char *fname);
 
 /* the null action object used for pure validation */
 const JsonSemAction nullSemAction =
@@ -437,7 +439,7 @@ allocate_incremental_state(JsonLexContext *lex)
 			   *fnull;
 
 	lex->inc_state = ALLOC0(sizeof(JsonIncrementalState));
-	pstack = ALLOC(sizeof(JsonParserStack));
+	pstack = ALLOC0(sizeof(JsonParserStack));
 	prediction = ALLOC(JS_STACK_CHUNK_SIZE * JS_MAX_PROD_LEN);
 	fnames = ALLOC(JS_STACK_CHUNK_SIZE * sizeof(char *));
 	fnull = ALLOC(JS_STACK_CHUNK_SIZE * sizeof(bool));
@@ -464,10 +466,17 @@ allocate_incremental_state(JsonLexContext *lex)
 	lex->pstack = pstack;
 	lex->pstack->stack_size = JS_STACK_CHUNK_SIZE;
 	lex->pstack->prediction = prediction;
-	lex->pstack->pred_index = 0;
 	lex->pstack->fnames = fnames;
 	lex->pstack->fnull = fnull;
 
+	/*
+	 * fnames between 0 and lex_level must always be defined so that
+	 * freeJsonLexContext() can handle them safely. inc/dec_lex_level() handle
+	 * the rest.
+	 */
+	Assert(lex->lex_level == 0);
+	lex->pstack->fnames[0] = NULL;
+
 	lex->incremental = true;
 	return true;
 }
@@ -530,6 +539,25 @@ makeJsonLexContextIncremental(JsonLexContext *lex, int encoding,
 	return lex;
 }
 
+void
+setJsonLexContextOwnsTokens(JsonLexContext *lex, bool owned_by_context)
+{
+	if (lex->incremental && lex->inc_state->started)
+	{
+		/*
+		 * Switching this flag after parsing has already started is a
+		 * programming error.
+		 */
+		Assert(false);
+		return;
+	}
+
+	if (owned_by_context)
+		lex->flags |= JSONLEX_CTX_OWNS_TOKENS;
+	else
+		lex->flags &= ~JSONLEX_CTX_OWNS_TOKENS;
+}
+
 static inline bool
 inc_lex_level(JsonLexContext *lex)
 {
@@ -569,12 +597,23 @@ inc_lex_level(JsonLexContext *lex)
 	}
 
 	lex->lex_level += 1;
+
+	if (lex->incremental)
+	{
+		/*
+		 * Ensure freeJsonLexContext() remains safe even if no fname is
+		 * assigned at this level.
+		 */
+		lex->pstack->fnames[lex->lex_level] = NULL;
+	}
+
 	return true;
 }
 
 static inline void
 dec_lex_level(JsonLexContext *lex)
 {
+	set_fname(lex, NULL);		/* free the current level's fname, if needed */
 	lex->lex_level -= 1;
 }
 
@@ -608,6 +647,15 @@ have_prediction(JsonParserStack *pstack)
 static inline void
 set_fname(JsonLexContext *lex, char *fname)
 {
+	if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+	{
+		/*
+		 * Don't leak prior fnames. If one hasn't been assigned yet,
+		 * inc_lex_level ensured that it's NULL (and therefore safe to free).
+		 */
+		FREE(lex->pstack->fnames[lex->lex_level]);
+	}
+
 	lex->pstack->fnames[lex->lex_level] = fname;
 }
 
@@ -655,8 +703,19 @@ freeJsonLexContext(JsonLexContext *lex)
 		jsonapi_termStringInfo(&lex->inc_state->partial_token);
 		FREE(lex->inc_state);
 		FREE(lex->pstack->prediction);
+
+		if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+		{
+			int			i;
+
+			/* Clean up any tokens that were left behind. */
+			for (i = 0; i <= lex->lex_level; i++)
+				FREE(lex->pstack->fnames[i]);
+		}
+
 		FREE(lex->pstack->fnames);
 		FREE(lex->pstack->fnull);
+		FREE(lex->pstack->scalar_val);
 		FREE(lex->pstack);
 	}
 
@@ -826,6 +885,7 @@ pg_parse_json_incremental(JsonLexContext *lex,
 	lex->input = lex->token_terminator = lex->line_start = json;
 	lex->input_length = len;
 	lex->inc_state->is_last_chunk = is_last;
+	lex->inc_state->started = true;
 
 	/* get the initial token */
 	result = json_lex(lex);
@@ -1086,6 +1146,17 @@ pg_parse_json_incremental(JsonLexContext *lex,
 						if (sfunc != NULL)
 						{
 							result = (*sfunc) (sem->semstate, pstack->scalar_val, pstack->scalar_tok);
+
+							/*
+							 * Either ownership of the token passed to the
+							 * callback, or we need to free it now. Either
+							 * way, clear our pointer to it so it doesn't get
+							 * freed in the future.
+							 */
+							if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+								FREE(pstack->scalar_val);
+							pstack->scalar_val = NULL;
+
 							if (result != JSON_SUCCESS)
 								return result;
 						}
@@ -1221,11 +1292,17 @@ parse_scalar(JsonLexContext *lex, const JsonSemAction *sem)
 	/* consume the token */
 	result = json_lex(lex);
 	if (result != JSON_SUCCESS)
+	{
+		FREE(val);
 		return result;
+	}
 
-	/* invoke the callback */
+	/* invoke the callback, which may take ownership of val */
 	result = (*sfunc) (sem->semstate, val, tok);
 
+	if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+		FREE(val);
+
 	return result;
 }
 
@@ -1238,7 +1315,7 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 	 * generally call a field name a "key".
 	 */
 
-	char	   *fname = NULL;	/* keep compiler quiet */
+	char	   *fname = NULL;
 	json_ofield_action ostart = sem->object_field_start;
 	json_ofield_action oend = sem->object_field_end;
 	bool		isnull;
@@ -1255,11 +1332,17 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 	}
 	result = json_lex(lex);
 	if (result != JSON_SUCCESS)
+	{
+		FREE(fname);
 		return result;
+	}
 
 	result = lex_expect(JSON_PARSE_OBJECT_LABEL, lex, JSON_TOKEN_COLON);
 	if (result != JSON_SUCCESS)
+	{
+		FREE(fname);
 		return result;
+	}
 
 	tok = lex_peek(lex);
 	isnull = tok == JSON_TOKEN_NULL;
@@ -1268,7 +1351,7 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 	{
 		result = (*ostart) (sem->semstate, fname, isnull);
 		if (result != JSON_SUCCESS)
-			return result;
+			goto ofield_cleanup;
 	}
 
 	switch (tok)
@@ -1283,16 +1366,19 @@ parse_object_field(JsonLexContext *lex, const JsonSemAction *sem)
 			result = parse_scalar(lex, sem);
 	}
 	if (result != JSON_SUCCESS)
-		return result;
+		goto ofield_cleanup;
 
 	if (oend != NULL)
 	{
 		result = (*oend) (sem->semstate, fname, isnull);
 		if (result != JSON_SUCCESS)
-			return result;
+			goto ofield_cleanup;
 	}
 
-	return JSON_SUCCESS;
+ofield_cleanup:
+	if (lex->flags & JSONLEX_CTX_OWNS_TOKENS)
+		FREE(fname);
+	return result;
 }
 
 static JsonParseErrorType
diff --git a/src/include/common/jsonapi.h b/src/include/common/jsonapi.h
index c524ff5be8..167615a557 100644
--- a/src/include/common/jsonapi.h
+++ b/src/include/common/jsonapi.h
@@ -92,9 +92,11 @@ typedef struct JsonIncrementalState JsonIncrementalState;
  * conjunction with token_start.
  *
  * JSONLEX_FREE_STRUCT/STRVAL are used to drive freeJsonLexContext.
+ * JSONLEX_CTX_OWNS_TOKENS is used by setJsonLexContextOwnsTokens.
  */
 #define JSONLEX_FREE_STRUCT			(1 << 0)
 #define JSONLEX_FREE_STRVAL			(1 << 1)
+#define JSONLEX_CTX_OWNS_TOKENS		(1 << 2)
 typedef struct JsonLexContext
 {
 	const char *input;
@@ -130,9 +132,10 @@ typedef JsonParseErrorType (*json_scalar_action) (void *state, char *token, Json
  * to doing a pure parse with no side-effects, and is therefore exactly
  * what the json input routines do.
  *
- * The 'fname' and 'token' strings passed to these actions are palloc'd.
- * They are not free'd or used further by the parser, so the action function
- * is free to do what it wishes with them.
+ * By default, the 'fname' and 'token' strings passed to these actions are
+ * palloc'd.  They are not free'd or used further by the parser, so the action
+ * function is free to do what it wishes with them. This behavior may be
+ * modified by setJsonLexContextOwnsTokens().
  *
  * All action functions return JsonParseErrorType.  If the result isn't
  * JSON_SUCCESS, the parse is abandoned and that error code is returned.
@@ -216,6 +219,25 @@ extern JsonLexContext *makeJsonLexContextIncremental(JsonLexContext *lex,
 													 int encoding,
 													 bool need_escapes);
 
+/*
+ * Sets whether tokens passed to semantic action callbacks are owned by the
+ * context (in which case, the callback must duplicate the tokens for long-term
+ * storage) or by the callback (in which case, the callback must explicitly
+ * free tokens to avoid leaks).
+ *
+ * By default, this setting is false: the callback owns the tokens that are
+ * passed to it (and if parsing fails between the two object-field callbacks,
+ * the field name token will likely leak). If set to true, tokens will be freed
+ * by the lexer after the callback completes.
+ *
+ * Setting this to true is important for long-lived clients (such as libpq)
+ * that must not leak memory during a parse failure. For a server backend using
+ * memory contexts, or a client application which will exit on parse failure,
+ * this setting is less critical.
+ */
+extern void setJsonLexContextOwnsTokens(JsonLexContext *lex,
+										bool owned_by_context);
+
 extern void freeJsonLexContext(JsonLexContext *lex);
 
 /* lex one token */
diff --git a/src/test/modules/test_json_parser/t/001_test_json_parser_incremental.pl b/src/test/modules/test_json_parser/t/001_test_json_parser_incremental.pl
index 8cc42e8e29..0c663b8e68 100644
--- a/src/test/modules/test_json_parser/t/001_test_json_parser_incremental.pl
+++ b/src/test/modules/test_json_parser/t/001_test_json_parser_incremental.pl
@@ -13,21 +13,24 @@ use FindBin;
 
 my $test_file = "$FindBin::RealBin/../tiny.json";
 
-my @exes =
-  ("test_json_parser_incremental", "test_json_parser_incremental_shlib");
+my @exes = (
+	[ "test_json_parser_incremental", ],
+	[ "test_json_parser_incremental", "-o", ],
+	[ "test_json_parser_incremental_shlib", ],
+	[ "test_json_parser_incremental_shlib", "-o", ]);
 
 foreach my $exe (@exes)
 {
-	note "testing executable $exe";
+	note "testing executable @$exe";
 
 	# Test the  usage error
-	my ($stdout, $stderr) = run_command([ $exe, "-c", 10 ]);
+	my ($stdout, $stderr) = run_command([ @$exe, "-c", 10 ]);
 	like($stderr, qr/Usage:/, 'error message if not enough arguments');
 
 	# Test that we get success for small chunk sizes from 64 down to 1.
 	for (my $size = 64; $size > 0; $size--)
 	{
-		($stdout, $stderr) = run_command([ $exe, "-c", $size, $test_file ]);
+		($stdout, $stderr) = run_command([ @$exe, "-c", $size, $test_file ]);
 
 		like($stdout, qr/SUCCESS/, "chunk size $size: test succeeds");
 		is($stderr, "", "chunk size $size: no error output");
diff --git a/src/test/modules/test_json_parser/t/002_inline.pl b/src/test/modules/test_json_parser/t/002_inline.pl
index 5b6c6dc4ae..71c462b319 100644
--- a/src/test/modules/test_json_parser/t/002_inline.pl
+++ b/src/test/modules/test_json_parser/t/002_inline.pl
@@ -13,7 +13,7 @@ use Test::More;
 use File::Temp qw(tempfile);
 
 my $dir = PostgreSQL::Test::Utils::tempdir;
-my $exe;
+my @exe;
 
 sub test
 {
@@ -35,7 +35,7 @@ sub test
 
 	foreach my $size (reverse(1 .. $chunk))
 	{
-		my ($stdout, $stderr) = run_command([ $exe, "-c", $size, $fname ]);
+		my ($stdout, $stderr) = run_command([ @exe, "-c", $size, $fname ]);
 
 		if (defined($params{error}))
 		{
@@ -53,13 +53,16 @@ sub test
 	}
 }
 
-my @exes =
-  ("test_json_parser_incremental", "test_json_parser_incremental_shlib");
+my @exes = (
+	[ "test_json_parser_incremental", ],
+	[ "test_json_parser_incremental", "-o", ],
+	[ "test_json_parser_incremental_shlib", ],
+	[ "test_json_parser_incremental_shlib", "-o", ]);
 
 foreach (@exes)
 {
-	$exe = $_;
-	note "testing executable $exe";
+	@exe = @$_;
+	note "testing executable @exe";
 
 	test("number", "12345");
 	test("string", '"hello"');
diff --git a/src/test/modules/test_json_parser/t/003_test_semantic.pl b/src/test/modules/test_json_parser/t/003_test_semantic.pl
index c11480172d..c57ccdb660 100644
--- a/src/test/modules/test_json_parser/t/003_test_semantic.pl
+++ b/src/test/modules/test_json_parser/t/003_test_semantic.pl
@@ -16,14 +16,17 @@ use File::Temp qw(tempfile);
 my $test_file = "$FindBin::RealBin/../tiny.json";
 my $test_out = "$FindBin::RealBin/../tiny.out";
 
-my @exes =
-  ("test_json_parser_incremental", "test_json_parser_incremental_shlib");
+my @exes = (
+	[ "test_json_parser_incremental", ],
+	[ "test_json_parser_incremental", "-o", ],
+	[ "test_json_parser_incremental_shlib", ],
+	[ "test_json_parser_incremental_shlib", "-o", ]);
 
 foreach my $exe (@exes)
 {
-	note "testing executable $exe";
+	note "testing executable @$exe";
 
-	my ($stdout, $stderr) = run_command([ $exe, "-s", $test_file ]);
+	my ($stdout, $stderr) = run_command([ @$exe, "-s", $test_file ]);
 
 	is($stderr, "", "no error output");
 
diff --git a/src/test/modules/test_json_parser/test_json_parser_incremental.c b/src/test/modules/test_json_parser/test_json_parser_incremental.c
index 294e5f74ea..0b02b5203b 100644
--- a/src/test/modules/test_json_parser/test_json_parser_incremental.c
+++ b/src/test/modules/test_json_parser/test_json_parser_incremental.c
@@ -18,6 +18,10 @@
  * If the -s flag is given, the program does semantic processing. This should
  * just mirror back the json, albeit with white space changes.
  *
+ * If the -o flag is given, the JSONLEX_CTX_OWNS_TOKENS flag is set. (This can
+ * be used in combination with a leak sanitizer; without the option, the parser
+ * may leak memory with invalid JSON.)
+ *
  * The argument specifies the file containing the JSON input.
  *
  *-------------------------------------------------------------------------
@@ -72,6 +76,8 @@ static JsonSemAction sem = {
 	.scalar = do_scalar
 };
 
+static bool lex_owns_tokens = false;
+
 int
 main(int argc, char **argv)
 {
@@ -88,10 +94,11 @@ main(int argc, char **argv)
 	char	   *testfile;
 	int			c;
 	bool		need_strings = false;
+	int			ret = 0;
 
 	pg_logging_init(argv[0]);
 
-	while ((c = getopt(argc, argv, "c:s")) != -1)
+	while ((c = getopt(argc, argv, "c:os")) != -1)
 	{
 		switch (c)
 		{
@@ -100,6 +107,9 @@ main(int argc, char **argv)
 				if (chunk_size > BUFSIZE)
 					pg_fatal("chunk size cannot exceed %d", BUFSIZE);
 				break;
+			case 'o':			/* switch token ownership */
+				lex_owns_tokens = true;
+				break;
 			case 's':			/* do semantic processing */
 				testsem = &sem;
 				sem.semstate = palloc(sizeof(struct DoState));
@@ -112,7 +122,7 @@ main(int argc, char **argv)
 
 	if (optind < argc)
 	{
-		testfile = pg_strdup(argv[optind]);
+		testfile = argv[optind];
 		optind++;
 	}
 	else
@@ -122,6 +132,7 @@ main(int argc, char **argv)
 	}
 
 	makeJsonLexContextIncremental(&lex, PG_UTF8, need_strings);
+	setJsonLexContextOwnsTokens(&lex, lex_owns_tokens);
 	initStringInfo(&json);
 
 	if ((json_file = fopen(testfile, PG_BINARY_R)) == NULL)
@@ -160,7 +171,8 @@ main(int argc, char **argv)
 			if (result != JSON_INCOMPLETE)
 			{
 				fprintf(stderr, "%s\n", json_errdetail(result, &lex));
-				exit(1);
+				ret = 1;
+				goto cleanup;
 			}
 			resetStringInfo(&json);
 		}
@@ -172,15 +184,21 @@ main(int argc, char **argv)
 			if (result != JSON_SUCCESS)
 			{
 				fprintf(stderr, "%s\n", json_errdetail(result, &lex));
-				exit(1);
+				ret = 1;
+				goto cleanup;
 			}
 			if (!need_strings)
 				printf("SUCCESS!\n");
 			break;
 		}
 	}
+
+cleanup:
 	fclose(json_file);
-	exit(0);
+	freeJsonLexContext(&lex);
+	free(json.data);
+
+	return ret;
 }
 
 /*
@@ -230,7 +248,8 @@ do_object_field_start(void *state, char *fname, bool isnull)
 static JsonParseErrorType
 do_object_field_end(void *state, char *fname, bool isnull)
 {
-	/* nothing to do really */
+	if (!lex_owns_tokens)
+		free(fname);
 
 	return JSON_SUCCESS;
 }
@@ -291,6 +310,9 @@ do_scalar(void *state, char *token, JsonTokenType tokentype)
 	else
 		printf("%s", token);
 
+	if (!lex_owns_tokens)
+		free(token);
+
 	return JSON_SUCCESS;
 }
 
@@ -343,7 +365,8 @@ usage(const char *progname)
 {
 	fprintf(stderr, "Usage: %s [OPTION ...] testfile\n", progname);
 	fprintf(stderr, "Options:\n");
-	fprintf(stderr, "  -c chunksize      size of piece fed to parser (default 64)n");
+	fprintf(stderr, "  -c chunksize      size of piece fed to parser (default 64)\n");
+	fprintf(stderr, "  -o                set JSONLEX_CTX_OWNS_TOKENS for leak checking\n");
 	fprintf(stderr, "  -s                do semantic processing\n");
 
 }
-- 
2.34.1

