diff --git a/ext/json/ext/generator/generator.c b/ext/json/ext/generator/generator.c index d5c8bfd4..c3a48a19 100644 --- a/ext/json/ext/generator/generator.c +++ b/ext/json/ext/generator/generator.c @@ -12,6 +12,7 @@ typedef struct JSON_Generator_StateStruct { VALUE space_before; VALUE object_nl; VALUE array_nl; + VALUE as_json; long max_nesting; long depth; @@ -30,8 +31,8 @@ typedef struct JSON_Generator_StateStruct { static VALUE mJSON, cState, mString_Extend, eGeneratorError, eNestingError, Encoding_UTF_8; static ID i_to_s, i_to_json, i_new, i_pack, i_unpack, i_create_id, i_extend, i_encode; -static ID sym_indent, sym_space, sym_space_before, sym_object_nl, sym_array_nl, sym_max_nesting, sym_allow_nan, - sym_ascii_only, sym_depth, sym_buffer_initial_length, sym_script_safe, sym_escape_slash, sym_strict; +static VALUE sym_indent, sym_space, sym_space_before, sym_object_nl, sym_array_nl, sym_max_nesting, sym_allow_nan, + sym_ascii_only, sym_depth, sym_buffer_initial_length, sym_script_safe, sym_escape_slash, sym_strict, sym_as_json; #define GET_STATE_TO(self, state) \ @@ -674,6 +675,7 @@ static void State_mark(void *ptr) rb_gc_mark_movable(state->space_before); rb_gc_mark_movable(state->object_nl); rb_gc_mark_movable(state->array_nl); + rb_gc_mark_movable(state->as_json); } static void State_compact(void *ptr) @@ -684,6 +686,7 @@ static void State_compact(void *ptr) state->space_before = rb_gc_location(state->space_before); state->object_nl = rb_gc_location(state->object_nl); state->array_nl = rb_gc_location(state->array_nl); + state->as_json = rb_gc_location(state->as_json); } static void State_free(void *ptr) @@ -740,6 +743,7 @@ static void vstate_spill(struct generate_json_data *data) RB_OBJ_WRITTEN(vstate, Qundef, state->space_before); RB_OBJ_WRITTEN(vstate, Qundef, state->object_nl); RB_OBJ_WRITTEN(vstate, Qundef, state->array_nl); + RB_OBJ_WRITTEN(vstate, Qundef, state->as_json); } static inline VALUE vstate_get(struct generate_json_data *data) @@ -1003,6 +1007,8 @@ static void generate_json_float(FBuffer *buffer, struct generate_json_data *data static void generate_json(FBuffer *buffer, struct generate_json_data *data, JSON_Generator_State *state, VALUE obj) { VALUE tmp; + bool as_json_called = false; +start: if (obj == Qnil) { generate_json_null(buffer, data, state, obj); } else if (obj == Qfalse) { @@ -1042,7 +1048,13 @@ static void generate_json(FBuffer *buffer, struct generate_json_data *data, JSON default: general: if (state->strict) { - raise_generator_error(obj, "%"PRIsVALUE" not allowed in JSON", CLASS_OF(obj)); + if (RTEST(state->as_json) && !as_json_called) { + obj = rb_proc_call_with_block(state->as_json, 1, &obj, Qnil); + as_json_called = true; + goto start; + } else { + raise_generator_error(obj, "%"PRIsVALUE" not allowed in JSON", CLASS_OF(obj)); + } } else if (rb_respond_to(obj, i_to_json)) { tmp = rb_funcall(obj, i_to_json, 1, vstate_get(data)); Check_Type(tmp, T_STRING); @@ -1132,6 +1144,7 @@ static VALUE cState_init_copy(VALUE obj, VALUE orig) objState->space_before = origState->space_before; objState->object_nl = origState->object_nl; objState->array_nl = origState->array_nl; + objState->as_json = origState->as_json; return obj; } @@ -1504,6 +1517,7 @@ static int configure_state_i(VALUE key, VALUE val, VALUE _arg) else if (key == sym_script_safe) { state->script_safe = RTEST(val); } else if (key == sym_escape_slash) { state->script_safe = RTEST(val); } else if (key == sym_strict) { state->strict = RTEST(val); } + else if (key == sym_as_json) { state->as_json = rb_convert_type(val, T_DATA, "Proc", "to_proc"); } return ST_CONTINUE; } @@ -1682,6 +1696,7 @@ void Init_generator(void) sym_script_safe = ID2SYM(rb_intern("script_safe")); sym_escape_slash = ID2SYM(rb_intern("escape_slash")); sym_strict = ID2SYM(rb_intern("strict")); + sym_as_json = ID2SYM(rb_intern("as_json")); usascii_encindex = rb_usascii_encindex(); utf8_encindex = rb_utf8_encindex(); diff --git a/lib/json/common.rb b/lib/json/common.rb index 197ae11f..c5dbd8db 100644 --- a/lib/json/common.rb +++ b/lib/json/common.rb @@ -841,6 +841,18 @@ def merge_dump_options(opts, strict: NOT_SET) class << self private :merge_dump_options end + + class Coder + def initialize(options = nil, &as_json) + default_options = { strict: true, as_json: as_json } + + @options = options ? options.merge(default_options) : default_options + end + + def dump(object) + State.generate(object, @options, nil) + end + end end module ::Kernel diff --git a/test/json/json_generator_test.rb b/test/json/json_generator_test.rb index 8dd3913d..c7bd2609 100755 --- a/test/json/json_generator_test.rb +++ b/test/json/json_generator_test.rb @@ -661,4 +661,31 @@ def test_string_ext_included_calls_super def test_nonutf8_encoding assert_equal("\"5\u{b0}\"", "5\xb0".dup.force_encoding(Encoding::ISO_8859_1).to_json) end + + def test_json_coder_with_proc + coder = JSON::Coder.new do |object| + "[Object object]" + end + assert_equal %(["[Object object]"]), coder.dump([Object.new]) + end + + def test_json_coder_with_proc_with_unsupported_value + coder = JSON::Coder.new do |object| + Object.new + end + assert_raise(JSON::GeneratorError) { coder.dump([Object.new]) } + end + + def test_json_coder_options + coder = JSON::Coder.new(array_nl: "\n") do |object| + 42 + end + + assert_equal "[\n42\n]", coder.dump([Object.new]) + end + + def test_json_generate_as_json_convert_to_proc + object = Object.new + assert_equal object.object_id.to_json, JSON.generate(object, strict: true, as_json: :object_id) + end end