aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Include/internal/pycore_unionobject.h1
-rw-r--r--Lib/test/test_types.py22
-rw-r--r--Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst5
-rw-r--r--Objects/genericaliasobject.c6
-rw-r--r--Objects/typeobject.c14
-rw-r--r--Objects/unionobject.c20
6 files changed, 51 insertions, 17 deletions
diff --git a/Include/internal/pycore_unionobject.h b/Include/internal/pycore_unionobject.h
index fa8ba6ed944..4d82b6fbeae 100644
--- a/Include/internal/pycore_unionobject.h
+++ b/Include/internal/pycore_unionobject.h
@@ -10,6 +10,7 @@ extern "C" {
PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args);
PyAPI_DATA(PyTypeObject) _Py_UnionType;
+PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject* self, PyObject* param);
#ifdef __cplusplus
}
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index 75c5eee42dc..3058a02d6ee 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -713,6 +713,28 @@ class TypesTests(unittest.TestCase):
assert repr(int | None) == "int | None"
assert repr(int | typing.GenericAlias(list, int)) == "int | list[int]"
+ def test_or_type_operator_with_genericalias(self):
+ a = list[int]
+ b = list[str]
+ c = dict[float, str]
+ # equivalence with typing.Union
+ self.assertEqual(a | b | c, typing.Union[a, b, c])
+ # de-duplicate
+ self.assertEqual(a | c | b | b | a | c, a | b | c)
+ # order shouldn't matter
+ self.assertEqual(a | b, b | a)
+ self.assertEqual(repr(a | b | c),
+ "list[int] | list[str] | dict[float, str]")
+
+ class BadType(type):
+ def __eq__(self, other):
+ return 1 / 0
+
+ bt = BadType('bt', (), {})
+ # Comparison should fail and errors should propagate out for bad types.
+ with self.assertRaises(ZeroDivisionError):
+ list[int] | list[bt]
+
def test_ellipsis_type(self):
self.assertIsInstance(Ellipsis, types.EllipsisType)
diff --git a/Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst b/Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst
new file mode 100644
index 00000000000..499bb324fb9
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2020-11-01-23-34-56.bpo-42233.zOSzja.rst
@@ -0,0 +1,5 @@
+Allow ``GenericAlias`` objects to use :ref:`union type expressions <types-union>`.
+This allows expressions like ``list[int] | dict[float, str]`` where previously a
+``TypeError`` would have been thrown. This also fixes union type expressions
+not de-duplicating ``GenericAlias`` objects. (Contributed by Ken Jin in
+:issue:`42233`.)
diff --git a/Objects/genericaliasobject.c b/Objects/genericaliasobject.c
index 6508c69cbf7..28ea487a44f 100644
--- a/Objects/genericaliasobject.c
+++ b/Objects/genericaliasobject.c
@@ -2,6 +2,7 @@
#include "Python.h"
#include "pycore_object.h"
+#include "pycore_unionobject.h" // _Py_union_as_number
#include "structmember.h" // PyMemberDef
typedef struct {
@@ -573,6 +574,10 @@ ga_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return Py_GenericAlias(origin, arguments);
}
+static PyNumberMethods ga_as_number = {
+ .nb_or = (binaryfunc)_Py_union_type_or, // Add __or__ function
+};
+
// TODO:
// - argument clinic?
// - __doc__?
@@ -586,6 +591,7 @@ PyTypeObject Py_GenericAliasType = {
.tp_basicsize = sizeof(gaobject),
.tp_dealloc = ga_dealloc,
.tp_repr = ga_repr,
+ .tp_as_number = &ga_as_number, // allow X | Y of GenericAlias objs
.tp_as_mapping = &ga_as_mapping,
.tp_hash = ga_hash,
.tp_call = ga_call,
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 3822b8cf813..55bf9b3f389 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -6,7 +6,7 @@
#include "pycore_object.h"
#include "pycore_pyerrors.h"
#include "pycore_pystate.h" // _PyThreadState_GET()
-#include "pycore_unionobject.h" // _Py_Union()
+#include "pycore_unionobject.h" // _Py_Union(), _Py_union_type_or
#include "frameobject.h"
#include "structmember.h" // PyMemberDef
@@ -3789,19 +3789,9 @@ type_is_gc(PyTypeObject *type)
return type->tp_flags & Py_TPFLAGS_HEAPTYPE;
}
-static PyObject *
-type_or(PyTypeObject* self, PyObject* param) {
- PyObject *tuple = PyTuple_Pack(2, self, param);
- if (tuple == NULL) {
- return NULL;
- }
- PyObject *new_union = _Py_Union(tuple);
- Py_DECREF(tuple);
- return new_union;
-}
static PyNumberMethods type_as_number = {
- .nb_or = (binaryfunc)type_or, // Add __or__ function
+ .nb_or = _Py_union_type_or, // Add __or__ function
};
PyTypeObject PyType_Type = {
diff --git a/Objects/unionobject.c b/Objects/unionobject.c
index 1b7f8ab51a4..2308bfc9f2a 100644
--- a/Objects/unionobject.c
+++ b/Objects/unionobject.c
@@ -237,9 +237,19 @@ dedup_and_flatten_args(PyObject* args)
PyObject* i_element = PyTuple_GET_ITEM(args, i);
for (Py_ssize_t j = i + 1; j < arg_length; j++) {
PyObject* j_element = PyTuple_GET_ITEM(args, j);
- if (i_element == j_element) {
- is_duplicate = 1;
+ int is_ga = Py_TYPE(i_element) == &Py_GenericAliasType &&
+ Py_TYPE(j_element) == &Py_GenericAliasType;
+ // RichCompare to also deduplicate GenericAlias types (slower)
+ is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ)
+ : i_element == j_element;
+ // Should only happen if RichCompare fails
+ if (is_duplicate < 0) {
+ Py_DECREF(args);
+ Py_DECREF(new_args);
+ return NULL;
}
+ if (is_duplicate)
+ break;
}
if (!is_duplicate) {
Py_INCREF(i_element);
@@ -290,8 +300,8 @@ is_unionable(PyObject *obj)
type == &_Py_UnionType);
}
-static PyObject *
-type_or(PyTypeObject* self, PyObject* param)
+PyObject *
+_Py_union_type_or(PyObject* self, PyObject* param)
{
PyObject *tuple = PyTuple_Pack(2, self, param);
if (tuple == NULL) {
@@ -404,7 +414,7 @@ static PyMethodDef union_methods[] = {
{0}};
static PyNumberMethods union_as_number = {
- .nb_or = (binaryfunc)type_or, // Add __or__ function
+ .nb_or = _Py_union_type_or, // Add __or__ function
};
PyTypeObject _Py_UnionType = {