From 6757ddadf9bb49dec0d632cd801ec8de70bac109 Mon Sep 17 00:00:00 2001
From: Jason Numeroff <jnumeroff@hotmail.com>
Date: Sun, 18 Aug 2024 10:52:32 -0400
Subject: [PATCH] Optionally allow overriding the policy class for permitted
 params

---
 README.md                   | 17 +++++++++++++++++
 lib/pundit/authorization.rb | 12 ++++--------
 lib/pundit/context.rb       | 22 ++++++++++++++++++++++
 spec/authorization_spec.rb  | 15 +++++++++++++++
 spec/spec_helper.rb         |  4 ++++
 5 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index a1b59868..dc920b82 100644
--- a/README.md
+++ b/README.md
@@ -757,6 +757,23 @@ def pundit_params_for(_record)
 end
 ```
 
+You can pass an argument to override the policy class if necessary. For example:
+
+```ruby
+# app/controllers/posts_controller.rb
+class PostsController < ApplicationController
+  def update
+    @post = Post.find(params[:id])
+    if @post.update(permitted_attributes(@post), policy_class: PostPolicy)
+      redirect_to @post
+    else
+      render :edit
+    end
+  end
+end
+
+```
+
 ## RSpec
 
 ### Policy Specs
diff --git a/lib/pundit/authorization.rb b/lib/pundit/authorization.rb
index bc4cc4f4..ea6583f1 100644
--- a/lib/pundit/authorization.rb
+++ b/lib/pundit/authorization.rb
@@ -121,15 +121,11 @@ def policy(record)
     # @param record [Object] the object we're retrieving permitted attributes for
     # @param action [Symbol, String] the name of the action being performed on the record (e.g. `:update`).
     #   If omitted then this defaults to the Rails controller action name.
+    # @param policy_class [Class] the policy class we want to force use of
     # @return [Hash{String => Object}] the permitted attributes
-    def permitted_attributes(record, action = action_name)
-      policy = policy(record)
-      method_name = if policy.respond_to?("permitted_attributes_for_#{action}")
-        "permitted_attributes_for_#{action}"
-      else
-        "permitted_attributes"
-      end
-      pundit_params_for(record).permit(*policy.public_send(method_name))
+    def permitted_attributes(record, action = action_name, policy_class: nil)
+      required_params = pundit_params_for(record)
+      pundit.permitted_attributes(record, action: action, required_params: required_params, policy_class: policy_class)
     end
 
     # Retrieves the params for the given record.
diff --git a/lib/pundit/context.rb b/lib/pundit/context.rb
index a5f86716..24a6ea21 100644
--- a/lib/pundit/context.rb
+++ b/lib/pundit/context.rb
@@ -99,6 +99,28 @@ def policy!(record)
       cached_find(record, &:policy!)
     end
 
+    # Retrieves a set of permitted attributes from the policy by instantiating
+    # the policy class for the given record and calling `permitted_attributes` on
+    # it, or `permitted_attributes_for_{action}` if `action` is defined. It then infers
+    # what key the record should have in the params hash and retrieves the
+    # permitted attributes from the params hash under that key.
+    #
+    # @see https://github.com/varvet/pundit#strong-parameters
+    # @param record [Object] the object we're retrieving permitted attributes for
+    # @param action [Symbol, String] the name of the action being performed on the record (e.g. `:update`).
+    # @param required_params [ActionController::Parameters] the params
+    # @param policy_class [Class] the policy class we want to force use of
+    # @return [Hash{String => Object}] the permitted attributes
+    def permitted_attributes(record, action:, required_params:, policy_class: nil)
+      policy = policy_class ? policy_class.new(user, record) : policy(record)
+      method_name = if policy.respond_to?("permitted_attributes_for_#{action}")
+        "permitted_attributes_for_#{action}"
+      else
+        "permitted_attributes"
+      end
+      required_params.permit(*policy.public_send(method_name))
+    end
+
     private
 
     def cached_find(record)
diff --git a/spec/authorization_spec.rb b/spec/authorization_spec.rb
index 8bfa3fcb..a08b6ad6 100644
--- a/spec/authorization_spec.rb
+++ b/spec/authorization_spec.rb
@@ -208,6 +208,21 @@ def to_params(*args, **kwargs, &block)
       expect(Controller.new(double, action, params).permitted_attributes(post).to_h).to eq("votes" => 5)
     end
 
+    it "checks different policy for permitted attributes" do
+      params = to_params(
+        post: {
+          title: "Hello",
+          votes: 5
+        }
+      )
+
+      action = "update"
+
+      expect(Controller.new(user, action, params).permitted_attributes(post, policy_class: PublicationPolicy).to_h).to eq(
+        "title" => "Hello"
+      )
+    end
+
     it "checks policy for permitted attributes for record of a ActiveModel type" do
       customer_post = Customer::Post.new(user)
       params = to_params(
diff --git a/spec/spec_helper.rb b/spec/spec_helper.rb
index ff70ac0d..2369f08a 100644
--- a/spec/spec_helper.rb
+++ b/spec/spec_helper.rb
@@ -161,6 +161,10 @@ def resolve
   def create?
     true
   end
+
+  def permitted_attributes
+    [:title]
+  end
 end
 
 class Comment